diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000..2b65f6fe3cc80 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.bat text eol=crlf +*.cmd text eol=crlf diff --git a/.rat-excludes b/.rat-excludes index ae9745673c87d..20e3372464386 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -1,5 +1,6 @@ target .gitignore +.gitattributes .project .classpath .mima-excludes diff --git a/LICENSE b/LICENSE index f1732fb47afc0..3c667bf45059a 100644 --- a/LICENSE +++ b/LICENSE @@ -754,7 +754,7 @@ SUCH DAMAGE. ======================================================================== -For Timsort (core/src/main/java/org/apache/spark/util/collection/Sorter.java): +For Timsort (core/src/main/java/org/apache/spark/util/collection/TimSort.java): ======================================================================== Copyright (C) 2008 The Android Open Source Project @@ -771,6 +771,25 @@ See the License for the specific language governing permissions and limitations under the License. +======================================================================== +For LimitedInputStream + (network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java): +======================================================================== +Copyright (C) 2007 The Guava Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + ======================================================================== BSD-style licenses ======================================================================== diff --git a/README.md b/README.md index 9916ac7b1ae8e..8d57d50da96c9 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,8 @@ and Spark Streaming for stream processing. ## Online Documentation You can find the latest Spark documentation, including a programming -guide, on the [project web page](http://spark.apache.org/documentation.html). +guide, on the [project web page](http://spark.apache.org/documentation.html) +and [project wiki](https://cwiki.apache.org/confluence/display/SPARK). This README file only contains basic setup instructions. ## Building Spark diff --git a/assembly/pom.xml b/assembly/pom.xml index 11d4bea9361ab..31a01e4d8e1de 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -146,10 +146,6 @@ com/google/common/base/Present* - - org.apache.commons.math3 - org.spark-project.commons.math3 - @@ -201,12 +197,6 @@ spark-hive_${scala.binary.version} ${project.version} - - - - - hive-0.12.0 - org.apache.spark spark-hive-thriftserver_${scala.binary.version} diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd index 3cd0579aea8d3..a4c099fb45b14 100644 --- a/bin/compute-classpath.cmd +++ b/bin/compute-classpath.cmd @@ -1,117 +1,117 @@ -@echo off - -rem -rem Licensed to the Apache Software Foundation (ASF) under one or more -rem contributor license agreements. See the NOTICE file distributed with -rem this work for additional information regarding copyright ownership. -rem The ASF licenses this file to You under the Apache License, Version 2.0 -rem (the "License"); you may not use this file except in compliance with -rem the License. You may obtain a copy of the License at -rem -rem http://www.apache.org/licenses/LICENSE-2.0 -rem -rem Unless required by applicable law or agreed to in writing, software -rem distributed under the License is distributed on an "AS IS" BASIS, -rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -rem See the License for the specific language governing permissions and -rem limitations under the License. -rem - -rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run" -rem script and the ExecutorRunner in standalone cluster mode. - -rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting -rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we -rem need to set it here because we use !datanucleus_jars! below. -if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion -setlocal enabledelayedexpansion -:skip_delayed_expansion - -set SCALA_VERSION=2.10 - -rem Figure out where the Spark framework is installed -set FWDIR=%~dp0..\ - -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" - -rem Build up classpath -set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH% - -if not "x%SPARK_CONF_DIR%"=="x" ( - set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR% -) else ( - set CLASSPATH=%CLASSPATH%;%FWDIR%conf -) - -if exist "%FWDIR%RELEASE" ( - for %%d in ("%FWDIR%lib\spark-assembly*.jar") do ( - set ASSEMBLY_JAR=%%d - ) -) else ( - for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do ( - set ASSEMBLY_JAR=%%d - ) -) - -set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR% - -rem When Hive support is needed, Datanucleus jars must be included on the classpath. -rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost. -rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is -rem built with Hive, so look for them there. -if exist "%FWDIR%RELEASE" ( - set datanucleus_dir=%FWDIR%lib -) else ( - set datanucleus_dir=%FWDIR%lib_managed\jars -) -set "datanucleus_jars=" -for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do ( - set datanucleus_jars=!datanucleus_jars!;%%d -) -set CLASSPATH=%CLASSPATH%;%datanucleus_jars% - -set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes - -set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes - -if "x%SPARK_TESTING%"=="x1" ( - rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH - rem so that local compilation takes precedence over assembled jar - set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH% -) - -rem Add hadoop conf dir - else FileSystem.*, etc fail -rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts -rem the configurtion files. -if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir - set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR% -:no_hadoop_conf_dir - -if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir - set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR% -:no_yarn_conf_dir - -rem A bit of a hack to allow calling this script within run2.cmd without seeing output -if "%DONT_PRINT_CLASSPATH%"=="1" goto exit - -echo %CLASSPATH% - -:exit +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run" +rem script and the ExecutorRunner in standalone cluster mode. + +rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting +rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we +rem need to set it here because we use !datanucleus_jars! below. +if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion +setlocal enabledelayedexpansion +:skip_delayed_expansion + +set SCALA_VERSION=2.10 + +rem Figure out where the Spark framework is installed +set FWDIR=%~dp0..\ + +rem Load environment variables from conf\spark-env.cmd, if it exists +if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" + +rem Build up classpath +set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH% + +if not "x%SPARK_CONF_DIR%"=="x" ( + set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR% +) else ( + set CLASSPATH=%CLASSPATH%;%FWDIR%conf +) + +if exist "%FWDIR%RELEASE" ( + for %%d in ("%FWDIR%lib\spark-assembly*.jar") do ( + set ASSEMBLY_JAR=%%d + ) +) else ( + for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do ( + set ASSEMBLY_JAR=%%d + ) +) + +set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR% + +rem When Hive support is needed, Datanucleus jars must be included on the classpath. +rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost. +rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is +rem built with Hive, so look for them there. +if exist "%FWDIR%RELEASE" ( + set datanucleus_dir=%FWDIR%lib +) else ( + set datanucleus_dir=%FWDIR%lib_managed\jars +) +set "datanucleus_jars=" +for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do ( + set datanucleus_jars=!datanucleus_jars!;%%d +) +set CLASSPATH=%CLASSPATH%;%datanucleus_jars% + +set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes + +set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes + +if "x%SPARK_TESTING%"=="x1" ( + rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH + rem so that local compilation takes precedence over assembled jar + set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH% +) + +rem Add hadoop conf dir - else FileSystem.*, etc fail +rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts +rem the configurtion files. +if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir + set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR% +:no_hadoop_conf_dir + +if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir + set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR% +:no_yarn_conf_dir + +rem A bit of a hack to allow calling this script within run2.cmd without seeing output +if "%DONT_PRINT_CLASSPATH%"=="1" goto exit + +echo %CLASSPATH% + +:exit diff --git a/core/pom.xml b/core/pom.xml index 8020a2daf81ec..41296e0eca330 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -46,7 +46,12 @@ org.apache.spark - network + spark-network-common_2.10 + ${project.version} + + + org.apache.spark + spark-network-shuffle_2.10 ${project.version} diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js new file mode 100644 index 0000000000000..badd85ed48c82 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Register functions to show/hide columns based on checkboxes. These need + * to be registered after the page loads. */ +$(function() { + $("span.expand-additional-metrics").click(function(){ + // Expand the list of additional metrics. + var additionalMetricsDiv = $(this).parent().find('.additional-metrics'); + $(additionalMetricsDiv).toggleClass('collapsed'); + + // Switch the class of the arrow from open to closed. + $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open'); + $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed'); + + // If clicking caused the metrics to expand, automatically check all options for additional + // metrics (don't trigger a click when collapsing metrics, because it leads to weird + // toggling behavior). + if (!$(additionalMetricsDiv).hasClass('collapsed')) { + $(this).parent().find('input:checkbox:not(:checked)').trigger('click'); + } + }); + + $("input:checkbox:not(:checked)").each(function() { + var column = "table ." + $(this).attr("name"); + $(column).hide(); + }); + // Stripe table rows after rows have been hidden to ensure correct striping. + stripeTables(); + + $("input:checkbox").click(function() { + var column = "table ." + $(this).attr("name"); + $(column).toggle(); + stripeTables(); + }); + + // Trigger a click on the checkbox if a user clicks the label next to it. + $("span.additional-metric-title").click(function() { + $(this).parent().find('input:checkbox').trigger('click'); + }); +}); diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js new file mode 100644 index 0000000000000..6bb03015abb51 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Adds background colors to stripe table rows. This is necessary (instead of using css or the + * table striping provided by bootstrap) to appropriately stripe tables with hidden rows. */ +function stripeTables() { + $("table.table-striped-custom").each(function() { + $(this).find("tr:not(:hidden)").each(function (index) { + if (index % 2 == 1) { + $(this).css("background-color", "#f9f9f9"); + } else { + $(this).css("background-color", "#ffffff"); + } + }); + }); +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 152bde5f6994f..db57712c83503 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -120,7 +120,51 @@ pre { border: none; } +.stacktrace-details { + max-height: 300px; + overflow-y: auto; + margin: 0; + transition: max-height 0.5s ease-out, padding 0.5s ease-out; +} + +.stacktrace-details.collapsed { + max-height: 0; + padding-top: 0; + padding-bottom: 0; + border: none; +} + +span.expand-additional-metrics { + cursor: pointer; +} + +span.additional-metric-title { + cursor: pointer; +} + +.additional-metrics.collapsed { + display: none; +} + .tooltip { font-weight: normal; } +.arrow-open { + width: 0; + height: 0; + border-left: 5px solid transparent; + border-right: 5px solid transparent; + border-top: 5px solid black; + float: left; + margin-top: 6px; +} + +.arrow-closed { + width: 0; + height: 0; + border-top: 5px solid transparent; + border-bottom: 5px solid transparent; + border-left: 5px solid black; + display: inline-block; +} diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index b2cf022baf29f..ef93009a074e7 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -66,7 +66,6 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // Lower and upper bounds on the number of executors. These are required. private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1) private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1) - verifyBounds() // How long there must be backlogged tasks for before an addition is triggered private val schedulerBacklogTimeout = conf.getLong( @@ -77,9 +76,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout) // How long an executor must be idle for before it is removed - private val removeThresholdSeconds = conf.getLong( + private val executorIdleTimeout = conf.getLong( "spark.dynamicAllocation.executorIdleTimeout", 600) + // During testing, the methods to actually kill and add executors are mocked out + private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) + + validateSettings() + // Number of executors to add in the next round private var numExecutorsToAdd = 1 @@ -103,17 +107,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // Polling loop interval (ms) private val intervalMillis: Long = 100 - // Whether we are testing this class. This should only be used internally. - private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) - // Clock used to schedule when executors should be added and removed private var clock: Clock = new RealClock /** - * Verify that the lower and upper bounds on the number of executors are valid. + * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. */ - private def verifyBounds(): Unit = { + private def validateSettings(): Unit = { if (minNumExecutors < 0 || maxNumExecutors < 0) { throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!") } @@ -124,6 +125,22 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " + s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!") } + if (schedulerBacklogTimeout <= 0) { + throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!") + } + if (sustainedSchedulerBacklogTimeout <= 0) { + throw new SparkException( + "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") + } + if (executorIdleTimeout <= 0) { + throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") + } + // Require external shuffle service for dynamic allocation + // Otherwise, we may lose shuffle files when killing executors + if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) { + throw new SparkException("Dynamic allocation of executors requires the external " + + "shuffle service. You may enable this through spark.shuffle.service.enabled.") + } } /** @@ -254,7 +271,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging val removeRequestAcknowledged = testing || sc.killExecutor(executorId) if (removeRequestAcknowledged) { logInfo(s"Removing executor $executorId because it has been idle for " + - s"$removeThresholdSeconds seconds (new desired total will be ${numExistingExecutors - 1})") + s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})") executorsPendingToRemove.add(executorId) true } else { @@ -329,8 +346,8 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging private def onExecutorIdle(executorId: String): Unit = synchronized { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $removeThresholdSeconds seconds)") - removeTimes(executorId) = clock.getTimeMillis + removeThresholdSeconds * 1000 + s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") + removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 } } @@ -419,7 +436,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { val executorId = blockManagerAdded.blockManagerId.executorId - if (executorId != "") { + if (executorId != SparkContext.DRIVER_IDENTIFIER) { allocationManager.onExecutorAdded(executorId) } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4cb0bd4142435..7d96962c4acd7 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -178,6 +178,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } } else { + logError("Missing all output locations for shuffle " + shuffleId) throw new MetadataFetchFailedException( shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId) } @@ -348,7 +349,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr new ConcurrentHashMap[Int, Array[MapStatus]] } -private[spark] object MapOutputTracker { +private[spark] object MapOutputTracker extends Logging { // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will @@ -381,6 +382,7 @@ private[spark] object MapOutputTracker { statuses.map { status => if (status == null) { + logError("Missing an output location for shuffle " + shuffleId) throw new MetadataFetchFailedException( shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) } else { diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 0e0f1a7b2377e..dbff9d12b5ad7 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication} import org.apache.hadoop.io.Text import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.network.sasl.SecretKeyHolder /** * Spark class responsible for security. @@ -84,7 +85,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * Authenticator installed in the SecurityManager to how it does the authentication * and in this case gets the user name and password from the request. * - * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously + * - BlockTransferService -> The Spark BlockTransferServices uses java nio to asynchronously * exchange messages. For this we use the Java SASL * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 * as the authentication mechanism. This means the shared secret is not passed @@ -98,7 +99,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * of protection they want. If we support those, the messages will also have to * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. * - * Since the connectionManager does asynchronous messages passing, the SASL + * Since the NioBlockTransferService does asynchronous messages passing, the SASL * authentication is a bit more complex. A ConnectionManager can be both a client * and a Server, so for a particular connection is has to determine what to do. * A ConnectionId was added to be able to track connections and is used to @@ -107,6 +108,10 @@ import org.apache.spark.deploy.SparkHadoopUtil * and waits for the response from the server and does the handshake before sending * the real message. * + * The NettyBlockTransferService ensures that SASL authentication is performed + * synchronously prior to any other communication on a connection. This is done in + * SaslClientBootstrap on the client side and SaslRpcHandler on the server side. + * * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters * can be used. Yarn requires a specific AmIpFilter be installed for security to work * properly. For non-Yarn deployments, users can write a filter to go through a @@ -139,7 +144,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * can take place. */ -private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { +private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder { // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" @@ -337,4 +342,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { * @return the secret key as a String if authentication is enabled, otherwise returns null */ def getSecretKey(): String = secretKey + + // Default SecurityManager only has a single secret key, so ignore appId. + override def getSaslUser(appId: String): String = getSaslUser() + override def getSecretKey(appId: String): String = getSecretKey() } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index ad0a9017afead..4c6c86c7bad78 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -217,6 +217,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { */ getAll.filter { case (k, _) => isAkkaConf(k) } + /** + * Returns the Spark application id, valid in the Driver after TaskScheduler registration and + * from the start in the Executor. + */ + def getAppId: String = get("spark.app.id") + /** Does the configuration contain a given parameter? */ def contains(key: String): Boolean = settings.contains(key) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 73668e83bbb1d..03ea672c813d1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -21,9 +21,8 @@ import scala.language.implicitConversions import java.io._ import java.net.URI -import java.util.Arrays +import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.AtomicInteger -import java.util.{Properties, UUID} import java.util.UUID.randomUUID import scala.collection.{Map, Set} import scala.collection.generic.Growable @@ -41,7 +40,8 @@ import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.input.WholeTextFileInputFormat +import org.apache.spark.executor.TriggerThreadDump +import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ @@ -51,7 +51,7 @@ import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ import org.apache.spark.ui.SparkUI import org.apache.spark.ui.jobs.JobProgressListener -import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} +import org.apache.spark.util._ /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -313,6 +313,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { val applicationId: String = taskScheduler.applicationId() conf.set("spark.app.id", applicationId) + env.blockManager.initialize(applicationId) + val metricsSystem = env.metricsSystem // The metrics system for Driver need to be set spark.app.id to app ID. @@ -361,6 +363,29 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { override protected def childValue(parent: Properties): Properties = new Properties(parent) } + /** + * Called by the web UI to obtain executor thread dumps. This method may be expensive. + * Logs an error and returns None if we failed to obtain a thread dump, which could occur due + * to an executor being dead or unresponsive or due to network issues while sending the thread + * dump message back to the driver. + */ + private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = { + try { + if (executorId == SparkContext.DRIVER_IDENTIFIER) { + Some(Utils.getThreadDump()) + } else { + val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get + val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) + Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, + AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + } + } catch { + case e: Exception => + logError(s"Exception getting thread dump from executor $executorId", e) + None + } + } + private[spark] def getLocalProperties: Properties = localProperties.get() private[spark] def setLocalProperties(props: Properties) { @@ -533,6 +558,73 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { minPartitions).setName(path) } + + /** + * :: Experimental :: + * + * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file + * (useful for binary data) + * + * For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do + * `val rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * + * then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + * + * @param minPartitions A suggestion value of the minimal splitting number for input data. + * + * @note Small files are preferred; very large files may cause bad performance. + */ + @Experimental + def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions): + RDD[(String, PortableDataStream)] = { + val job = new NewHadoopJob(hadoopConfiguration) + NewFileInputFormat.addInputPath(job, new Path(path)) + val updateConf = job.getConfiguration + new BinaryFileRDD( + this, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + updateConf, + minPartitions).setName(path) + } + + /** + * :: Experimental :: + * + * Load data from a flat binary file, assuming the length of each record is constant. + * + * @param path Directory to the input data files + * @param recordLength The length at which to split the records + * @return An RDD of data with values, represented as byte arrays + */ + @Experimental + def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration) + : RDD[Array[Byte]] = { + conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) + val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path, + classOf[FixedLengthBinaryInputFormat], + classOf[LongWritable], + classOf[BytesWritable], + conf=conf) + val data = br.map{ case (k, v) => v.getBytes} + data + } + /** * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable), @@ -1333,6 +1425,8 @@ object SparkContext extends Logging { private[spark] val SPARK_UNKNOWN_USER = "" + private[spark] val DRIVER_IDENTIFIER = "" + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 6a6dfda363974..e7454beddbfd0 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService -import org.apache.spark.network.netty.{NettyBlockTransferService} +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer @@ -156,7 +156,7 @@ object SparkEnv extends Logging { assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!") val hostname = conf.get("spark.driver.host") val port = conf.get("spark.driver.port").toInt - create(conf, "", hostname, port, true, isLocal, listenerBus) + create(conf, SparkContext.DRIVER_IDENTIFIER, hostname, port, true, isLocal, listenerBus) } /** @@ -274,9 +274,9 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) val blockTransferService = - conf.get("spark.shuffle.blockTransferService", "nio").toLowerCase match { + conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { case "netty" => - new NettyBlockTransferService(conf) + new NettyBlockTransferService(conf, securityManager) case "nio" => new NioBlockTransferService(conf, securityManager) } @@ -285,8 +285,9 @@ object SparkEnv extends Logging { "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + // NB: blockManager is not valid until initialize() is called later. val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) + serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 376e69cd997d5..40237596570de 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD /** diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala deleted file mode 100644 index a954fcc0c31fa..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.RealmCallback -import javax.security.sasl.RealmChoiceCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslClient -import javax.security.sasl.SaslException - -import scala.collection.JavaConversions.mapAsJavaMap - -import com.google.common.base.Charsets.UTF_8 - -/** - * Implements SASL Client logic for Spark - */ -private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging { - - /** - * Used to respond to server's counterpart, SaslServer with SASL tokens - * represented as byte arrays. - * - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST), - null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslClientCallbackHandler(securityMgr)) - - /** - * Used to initiate SASL handshake with server. - * @return response to challenge if needed - */ - def firstToken(): Array[Byte] = { - synchronized { - val saslToken: Array[Byte] = - if (saslClient != null && saslClient.hasInitialResponse()) { - logDebug("has initial response") - saslClient.evaluateChallenge(new Array[Byte](0)) - } else { - new Array[Byte](0) - } - saslToken - } - } - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslClient != null) saslClient.isComplete() else false - } - } - - /** - * Respond to server's SASL token. - * @param saslTokenMessage contains server's SASL token - * @return client's response SASL token - */ - def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = { - synchronized { - if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslClient might be using. - */ - def dispose() { - synchronized { - if (saslClient != null) { - try { - saslClient.dispose() - } catch { - case e: SaslException => // ignored - } finally { - saslClient = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * that works with share secrets. - */ - private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends - CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) - private val secretKey = securityMgr.getSecretKey() - private val userPassword: Array[Char] = SparkSaslServer.encodePassword( - if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8)) - - /** - * Implementation used to respond to SASL request from the server. - * - * @param callbacks objects that indicate what credential information the - * server's SaslServer requires from the client. - */ - override def handle(callbacks: Array[Callback]) { - logDebug("in the sasl client callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL client callback: setting username: " + userName) - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL client callback: setting userPassword") - pc.setPassword(userPassword) - } - case rc: RealmCallback => { - logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case cb: RealmChoiceCallback => {} - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback") - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala deleted file mode 100644 index 7c2afb364661f..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.AuthorizeCallback -import javax.security.sasl.RealmCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslException -import javax.security.sasl.SaslServer -import scala.collection.JavaConversions.mapAsJavaMap - -import com.google.common.base.Charsets.UTF_8 -import org.apache.commons.net.util.Base64 - -/** - * Encapsulates SASL server logic - */ -private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging { - - /** - * Actual SASL work done by this object from javax.security.sasl. - */ - private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null, - SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslDigestCallbackHandler(securityMgr)) - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslServer != null) saslServer.isComplete() else false - } - } - - /** - * Used to respond to server SASL tokens. - * @param token Server's SASL token - * @return response to send back to the server. - */ - def response(token: Array[Byte]): Array[Byte] = { - synchronized { - if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslServer might be using. - */ - def dispose() { - synchronized { - if (saslServer != null) { - try { - saslServer.dispose() - } catch { - case e: SaslException => // ignore - } finally { - saslServer = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * for SASL DIGEST-MD5 mechanism - */ - private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager) - extends CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) - - override def handle(callbacks: Array[Callback]) { - logDebug("In the sasl server callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL server callback: setting username") - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL server callback: setting userPassword") - val password: Array[Char] = - SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8)) - pc.setPassword(password) - } - case rc: RealmCallback => { - logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case ac: AuthorizeCallback => { - val authid = ac.getAuthenticationID() - val authzid = ac.getAuthorizationID() - if (authid.equals(authzid)) { - logDebug("set auth to true") - ac.setAuthorized(true) - } else { - logDebug("set auth to false") - ac.setAuthorized(false) - } - if (ac.isAuthorized()) { - logDebug("sasl server is authorized") - ac.setAuthorizedID(authzid) - } - } - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback") - } - } - } -} - -private[spark] object SparkSaslServer { - - /** - * This is passed as the server name when creating the sasl client/server. - * This could be changed to be configurable in the future. - */ - val SASL_DEFAULT_REALM = "default" - - /** - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - val DIGEST = "DIGEST-MD5" - - /** - * The quality of protection is just "auth". This means that we are doing - * authentication only, we are not supporting integrity or privacy protection of the - * communication channel after authentication. This could be changed to be configurable - * in the future. - */ - val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true") - - /** - * Encode a byte[] identifier as a Base64-encoded string. - * - * @param identifier identifier to encode - * @return Base64-encoded string - */ - def encodeIdentifier(identifier: Array[Byte]): String = { - new String(Base64.encodeBase64(identifier), UTF_8) - } - - /** - * Encode a password as a base64-encoded char[] array. - * @param password as a byte array. - * @return password as a char array. - */ - def encodePassword(password: Array[Byte]): Array[Char] = { - new String(Base64.encodeBase64(password), UTF_8).toCharArray() - } -} - diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 8f0c5e78416c2..af5fd8e0ac00c 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -69,11 +69,13 @@ case class FetchFailed( bmAddress: BlockManagerId, // Note that bmAddress can be null shuffleId: Int, mapId: Int, - reduceId: Int) + reduceId: Int, + message: String) extends TaskFailedReason { override def toErrorString: String = { val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString - s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId)" + s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " + + s"message=\n$message\n)" } } @@ -81,15 +83,48 @@ case class FetchFailed( * :: DeveloperApi :: * Task failed due to a runtime exception. This is the most common failure case and also captures * user program exceptions. + * + * `stackTrace` contains the stack trace of the exception itself. It still exists for backward + * compatibility. It's better to use `this(e: Throwable, metrics: Option[TaskMetrics])` to + * create `ExceptionFailure` as it will handle the backward compatibility properly. + * + * `fullStackTrace` is a better representation of the stack trace because it contains the whole + * stack trace including the exception and its causes */ @DeveloperApi case class ExceptionFailure( className: String, description: String, stackTrace: Array[StackTraceElement], + fullStackTrace: String, metrics: Option[TaskMetrics]) extends TaskFailedReason { - override def toErrorString: String = Utils.exceptionString(className, description, stackTrace) + + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics) + } + + override def toErrorString: String = + if (fullStackTrace == null) { + // fullStackTrace is added in 1.2.0 + // If fullStackTrace is null, use the old error string for backward compatibility + exceptionString(className, description, stackTrace) + } else { + fullStackTrace + } + + /** + * Return a nice string representation of the exception, including the stack trace. + * Note: It does not include the exception's causes, and is only used for backward compatibility. + */ + private def exceptionString( + className: String, + description: String, + stackTrace: Array[StackTraceElement]): String = { + val desc = if (description == null) "" else description + val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n") + s"$className: $desc\n$st" + } } /** @@ -117,8 +152,8 @@ case object TaskKilled extends TaskFailedReason { * the task crashed the JVM. */ @DeveloperApi -case object ExecutorLostFailure extends TaskFailedReason { - override def toErrorString: String = "ExecutorLostFailure (executor lost)" +case class ExecutorLostFailure(execId: String) extends TaskFailedReason { + override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)" } /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index efb8978f7ce12..5a8e5bb1f721a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -493,9 +493,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the top K elements from this RDD as defined by + * Returns the top k (largest) elements from this RDD as defined by * the specified Comparator[T]. - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @param comp the comparator that defines the order * @return an array of top elements */ @@ -507,9 +507,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the top K elements from this RDD using the + * Returns the top k (largest) elements from this RDD using the * natural ordering for T. - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @return an array of top elements */ def top(num: Int): JList[T] = { @@ -518,9 +518,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the first K elements from this RDD as defined by + * Returns the first k (smallest) elements from this RDD as defined by * the specified Comparator[T] and maintains the order. - * @param num the number of top elements to return + * @param num k, the number of elements to return * @param comp the comparator that defines the order * @return an array of top elements */ @@ -552,9 +552,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the first K elements from this RDD using the + * Returns the first k (smallest) elements from this RDD using the * natural ordering for T while maintain the order. - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @return an array of top elements */ def takeOrdered(num: Int): JList[T] = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 0565adf4d4ead..5c6e8d32c5c8a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -28,11 +28,13 @@ import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration +import org.apache.spark.input.PortableDataStream import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ -import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam} +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} @@ -202,6 +204,8 @@ class JavaSparkContext(val sc: SparkContext) def textFile(path: String, minPartitions: Int): JavaRDD[String] = sc.textFile(path, minPartitions) + + /** * Read a directory of text files from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI. Each file is read as a single record and returned in a @@ -245,6 +249,84 @@ class JavaSparkContext(val sc: SparkContext) def wholeTextFiles(path: String): JavaPairRDD[String, String] = new JavaPairRDD(sc.wholeTextFiles(path)) + /** + * Read a directory of binary files from HDFS, a local file system (available on all nodes), + * or any Hadoop-supported file system URI as a byte array. Each file is read as a single + * record and returned in a key-value pair, where the key is the path of each file, + * the value is the content of each file. + * + * For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do + * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * + * then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + * + * @note Small files are preferred; very large files but may cause bad performance. + * + * @param minPartitions A suggestion value of the minimal splitting number for input data. + */ + def binaryFiles(path: String, minPartitions: Int): JavaPairRDD[String, PortableDataStream] = + new JavaPairRDD(sc.binaryFiles(path, minPartitions)) + + /** + * :: Experimental :: + * + * Read a directory of binary files from HDFS, a local file system (available on all nodes), + * or any Hadoop-supported file system URI as a byte array. Each file is read as a single + * record and returned in a key-value pair, where the key is the path of each file, + * the value is the content of each file. + * + * For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do + * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * + * then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + * + * @note Small files are preferred; very large files but may cause bad performance. + */ + @Experimental + def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] = + new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions)) + + /** + * :: Experimental :: + * + * Load data from a flat binary file, assuming the length of each record is constant. + * + * @param path Directory to the input data files + * @return An RDD of data with values, represented as byte arrays + */ + @Experimental + def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = { + new JavaRDD(sc.binaryRecords(path, recordLength)) + } + /** Get an RDD for a Hadoop SequenceFile with given key and value types. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 49dc95f349eac..5ba66178e2b78 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -61,8 +61,7 @@ private[python] object Converter extends Logging { * Other objects are passed through without conversion. */ private[python] class WritableToJavaConverter( - conf: Broadcast[SerializableWritable[Configuration]], - batchSize: Int) extends Converter[Any, Any] { + conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or @@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter( map.put(convertWritable(k), convertWritable(v)) } map - case w: Writable => - if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w + case w: Writable => WritableUtils.clone(w, conf.value.value) case other => other } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 61b125ef7c6c1..45beb8fc8c925 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -21,13 +21,13 @@ import java.io._ import java.net._ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} +import org.apache.spark.input.PortableDataStream + import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials import com.google.common.base.Charsets.UTF_8 -import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec @@ -397,22 +397,33 @@ private[spark] object PythonRDD extends Logging { newIter.asInstanceOf[Iterator[String]].foreach { str => writeUTF(str, dataOut) } - case pair: Tuple2[_, _] => - pair._1 match { - case bytePair: Array[Byte] => - newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair => - dataOut.writeInt(pair._1.length) - dataOut.write(pair._1) - dataOut.writeInt(pair._2.length) - dataOut.write(pair._2) - } - case stringPair: String => - newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair => - writeUTF(pair._1, dataOut) - writeUTF(pair._2, dataOut) - } - case other => - throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass) + case stream: PortableDataStream => + newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream => + val bytes = stream.toArray() + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + case (key: String, stream: PortableDataStream) => + newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach { + case (key, stream) => + writeUTF(key, dataOut) + val bytes = stream.toArray() + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + case (key: String, value: String) => + newIter.asInstanceOf[Iterator[(String, String)]].foreach { + case (key, value) => + writeUTF(key, dataOut) + writeUTF(value, dataOut) + } + case (key: Array[Byte], value: Array[Byte]) => + newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { + case (key, value) => + dataOut.writeInt(key.length) + dataOut.write(key) + dataOut.writeInt(value.length) + dataOut.write(value) } case other => throw new SparkException("Unexpected element type " + first.getClass) @@ -442,7 +453,7 @@ private[spark] object PythonRDD extends Logging { val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -468,7 +479,7 @@ private[spark] object PythonRDD extends Logging { Some(path), inputFormatClass, keyClass, valueClass, mergedConf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -494,7 +505,7 @@ private[spark] object PythonRDD extends Logging { None, inputFormatClass, keyClass, valueClass, conf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -537,7 +548,7 @@ private[spark] object PythonRDD extends Logging { Some(path), inputFormatClass, keyClass, valueClass, mergedConf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -563,7 +574,7 @@ private[spark] object PythonRDD extends Logging { None, inputFormatClass, keyClass, valueClass, conf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -746,104 +757,6 @@ private[spark] object PythonRDD extends Logging { converted.saveAsHadoopDataset(new JobConf(conf)) } } - - - /** - * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). - */ - @deprecated("PySpark does not use it anymore", "1.1") - def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - SerDeUtil.initialize() - iter.flatMap { row => - unpickle.loads(row) match { - // in case of objects are pickled in batch mode - case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) - // not in batch mode - case obj: JMap[String @unchecked, _] => Seq(obj.toMap) - } - } - } - } - - /** - * Convert an RDD of serialized Python tuple to Array (no recursive conversions). - * It is only used by pyspark.sql. - */ - def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = { - - def toArray(obj: Any): Array[_] = { - obj match { - case objs: JArrayList[_] => - objs.toArray - case obj if obj.getClass.isArray => - obj.asInstanceOf[Array[_]].toArray - } - } - - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]].map(toArray) - } else { - Seq(toArray(obj)) - } - } - }.toJavaRDD() - } - - private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { - private val pickle = new Pickler() - private var batch = 1 - private val buffer = new mutable.ArrayBuffer[Any] - - override def hasNext(): Boolean = iter.hasNext - - override def next(): Array[Byte] = { - while (iter.hasNext && buffer.length < batch) { - buffer += iter.next() - } - val bytes = pickle.dumps(buffer.toArray) - val size = bytes.length - // let 1M < size < 10M - if (size < 1024 * 1024) { - batch *= 2 - } else if (size > 1024 * 1024 * 10 && batch > 1) { - batch /= 2 - } - buffer.clear() - bytes - } - } - - /** - * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by - * PySpark. - */ - def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { - jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } - } - - /** - * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. - */ - def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { - pyRDD.rdd.mapPartitions { iter => - SerDeUtil.initialize() - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala - } else { - Seq(obj) - } - } - }.toJavaRDD() - } } private diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index ebdc3533e0992..a4153aaa926f8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -18,8 +18,13 @@ package org.apache.spark.api.python import java.nio.ByteOrder +import java.util.{ArrayList => JArrayList} + +import org.apache.spark.api.java.JavaRDD import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Failure import scala.util.Try @@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging { } initialize() + + /** + * Convert an RDD of Java objects to Array (no recursive conversions). + * It is only used by pyspark.sql. + */ + def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = { + jrdd.rdd.map { + case objs: JArrayList[_] => + objs.toArray + case obj if obj.getClass.isArray => + obj.asInstanceOf[Array[_]].toArray + }.toJavaRDD() + } + + /** + * Choose batch size based on size of objects + */ + private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { + private val pickle = new Pickler() + private var batch = 1 + private val buffer = new mutable.ArrayBuffer[Any] + + override def hasNext: Boolean = iter.hasNext + + override def next(): Array[Byte] = { + while (iter.hasNext && buffer.length < batch) { + buffer += iter.next() + } + val bytes = pickle.dumps(buffer.toArray) + val size = bytes.length + // let 1M < size < 10M + if (size < 1024 * 1024) { + batch *= 2 + } else if (size > 1024 * 1024 * 10 && batch > 1) { + batch /= 2 + } + buffer.clear() + bytes + } + } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = { + jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } + } + + /** + * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. + */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { + pyRDD.rdd.mapPartitions { iter => + initialize() + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].asScala + } else { + Seq(obj) + } + } + }.toJavaRDD() + } + private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = { val pickle = new Pickler val kt = Try { @@ -128,17 +200,18 @@ private[spark] object SerDeUtil extends Logging { */ def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = { val (keyFailed, valueFailed) = checkPickle(rdd.first()) + rdd.mapPartitions { iter => - val pickle = new Pickler val cleaned = iter.map { case (k, v) => val key = if (keyFailed) k.toString else k val value = if (valueFailed) v.toString else v Array[Any](key, value) } - if (batchSize > 1) { - cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + if (batchSize == 0) { + new AutoBatchedPickler(cleaned) } else { - cleaned.map(pickle.dumps(_)) + val pickle = new Pickler + cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) } } } @@ -146,36 +219,22 @@ private[spark] object SerDeUtil extends Logging { /** * Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)]. */ - def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = { + def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = { def isPair(obj: Any): Boolean = { - Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) && + Option(obj.getClass.getComponentType).exists(!_.isPrimitive) && obj.asInstanceOf[Array[_]].length == 2 } - pyRDD.mapPartitions { iter => - initialize() - val unpickle = new Unpickler - val unpickled = - if (batchSerialized) { - iter.flatMap { batch => - unpickle.loads(batch) match { - case objs: java.util.List[_] => collectionAsScalaIterable(objs) - case other => throw new SparkException( - s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD") - } - } - } else { - iter.map(unpickle.loads(_)) - } - unpickled.map { - case obj if isPair(obj) => - // we only accept (K, V) - val arr = obj.asInstanceOf[Array[_]] - (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) - case other => throw new SparkException( - s"RDD element of type ${other.getClass.getName} cannot be used") - } + + val rdd = pythonToJava(pyRDD, batched).rdd + rdd.first match { + case obj if isPair(obj) => + // we only accept (K, V) + case other => throw new SparkException( + s"RDD element of type ${other.getClass.getName} cannot be used") + } + rdd.map { obj => + val arr = obj.asInstanceOf[Array[_]] + (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) } } - } - diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index e9ca9166eb4d6..c0cbd28a845be 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -176,11 +176,11 @@ object WriteInputFormatTestDataGenerator { // Create test data for arbitrary custom writable TestWritable val testClass = Seq( - ("1", TestWritable("test1", 123, 54.0)), - ("2", TestWritable("test2", 456, 8762.3)), - ("1", TestWritable("test3", 123, 423.1)), - ("3", TestWritable("test56", 456, 423.5)), - ("2", TestWritable("test2", 123, 5435.2)) + ("1", TestWritable("test1", 1, 1.0)), + ("2", TestWritable("test2", 2, 2.3)), + ("3", TestWritable("test3", 3, 3.1)), + ("5", TestWritable("test56", 5, 5.5)), + ("4", TestWritable("test4", 4, 4.2)) ) val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) } rdd.saveAsNewAPIHadoopFile(classPath, diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e28eaad8a5180..60ee115e393ce 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy +import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration @@ -133,14 +134,9 @@ class SparkHadoopUtil extends Logging { */ private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration) : Option[() => Long] = { - val qualifiedPath = path.getFileSystem(conf).makeQualified(path) - val scheme = qualifiedPath.toUri().getScheme() - val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) try { - val threadStats = stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) - val statisticsDataClass = - Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") - val getBytesReadMethod = statisticsDataClass.getDeclaredMethod("getBytesRead") + val threadStats = getFileSystemThreadStatistics(path, conf) + val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead") val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesRead = f() Some(() => f() - baselineBytesRead) @@ -151,6 +147,42 @@ class SparkHadoopUtil extends Logging { } } } + + /** + * Returns a function that can be called to find Hadoop FileSystem bytes written. If + * getFSBytesWrittenOnThreadCallback is called from thread r at time t, the returned callback will + * return the bytes written on r since t. Reflection is required because thread-level FileSystem + * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). + * Returns None if the required method can't be found. + */ + private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration) + : Option[() => Long] = { + try { + val threadStats = getFileSystemThreadStatistics(path, conf) + val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten") + val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum + val baselineBytesWritten = f() + Some(() => f() - baselineBytesWritten) + } catch { + case e: NoSuchMethodException => { + logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e) + None + } + } + } + + private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = { + val qualifiedPath = path.getFileSystem(conf).makeQualified(path) + val scheme = qualifiedPath.toUri().getScheme() + val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) + stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) + } + + private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { + val statisticsDataClass = + Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") + statisticsDataClass.getDeclaredMethod(methodName) + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index f97bf67fa5a3b..b43e68e40f791 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -158,8 +158,9 @@ object SparkSubmit { args.files = mergeFileLists(args.files, args.primaryResource) } args.files = mergeFileLists(args.files, args.pyFiles) - // Format python file paths properly before adding them to the PYTHONPATH - sysProps("spark.submit.pyFiles") = PythonRunner.formatPaths(args.pyFiles).mkString(",") + if (args.pyFiles != null) { + sysProps("spark.submit.pyFiles") = args.pyFiles + } } // Special flag to avoid deprecation warnings at the client @@ -273,15 +274,32 @@ object SparkSubmit { } } - // Properties given with --conf are superceded by other options, but take precedence over - // properties in the defaults file. + // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { sysProps.getOrElseUpdate(k, v) } - // Read from default spark properties, if any - for ((k, v) <- args.defaultSparkProperties) { - sysProps.getOrElseUpdate(k, v) + // Resolve paths in certain spark properties + val pathConfigs = Seq( + "spark.jars", + "spark.files", + "spark.yarn.jar", + "spark.yarn.dist.files", + "spark.yarn.dist.archives") + pathConfigs.foreach { config => + // Replace old URIs with resolved URIs, if they exist + sysProps.get(config).foreach { oldValue => + sysProps(config) = Utils.resolveURIs(oldValue) + } + } + + // Resolve and format python file paths properly before adding them to the PYTHONPATH. + // The resolving part is redundant in the case of --py-files, but necessary if the user + // explicitly sets `spark.submit.pyFiles` in his/her default properties file. + sysProps.get("spark.submit.pyFiles").foreach { pyFiles => + val resolvedPyFiles = Utils.resolveURIs(pyFiles) + val formattedPyFiles = PythonRunner.formatPaths(resolvedPyFiles).mkString(",") + sysProps("spark.submit.pyFiles") = formattedPyFiles } (childArgs, childClasspath, sysProps, childMainClass) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 72a452e0aefb5..f0e9ee67f6a67 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy import java.util.jar.JarFile -import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark.util.Utils @@ -72,39 +71,54 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St defaultProperties } - // Respect SPARK_*_MEMORY for cluster mode - driverMemory = sys.env.get("SPARK_DRIVER_MEMORY").orNull - executorMemory = sys.env.get("SPARK_EXECUTOR_MEMORY").orNull - + // Set parameters from command line arguments parseOpts(args.toList) - mergeSparkProperties() + // Populate `sparkProperties` map from properties file + mergeDefaultSparkProperties() + // Use `sparkProperties` map along with env vars to fill in any missing parameters + loadEnvironmentArguments() + checkRequiredArguments() /** - * Fill in any undefined values based on the default properties file or options passed in through - * the '--conf' flag. + * Merge values from the default properties file with those specified through --conf. + * When this is called, `sparkProperties` is already filled with configs from the latter. */ - private def mergeSparkProperties(): Unit = { + private def mergeDefaultSparkProperties(): Unit = { // Use common defaults file, if not specified by user propertiesFile = Option(propertiesFile).getOrElse(Utils.getDefaultPropertiesFile(env)) + // Honor --conf before the defaults file + defaultSparkProperties.foreach { case (k, v) => + if (!sparkProperties.contains(k)) { + sparkProperties(k) = v + } + } + } - val properties = HashMap[String, String]() - properties.putAll(defaultSparkProperties) - properties.putAll(sparkProperties) - - // Use properties file as fallback for values which have a direct analog to - // arguments in this script. - master = Option(master).orElse(properties.get("spark.master")).orNull - executorMemory = Option(executorMemory).orElse(properties.get("spark.executor.memory")).orNull - executorCores = Option(executorCores).orElse(properties.get("spark.executor.cores")).orNull + /** + * Load arguments from environment variables, Spark properties etc. + */ + private def loadEnvironmentArguments(): Unit = { + master = Option(master) + .orElse(sparkProperties.get("spark.master")) + .orElse(env.get("MASTER")) + .orNull + driverMemory = Option(driverMemory) + .orElse(sparkProperties.get("spark.driver.memory")) + .orElse(env.get("SPARK_DRIVER_MEMORY")) + .orNull + executorMemory = Option(executorMemory) + .orElse(sparkProperties.get("spark.executor.memory")) + .orElse(env.get("SPARK_EXECUTOR_MEMORY")) + .orNull + executorCores = Option(executorCores) + .orElse(sparkProperties.get("spark.executor.cores")) + .orNull totalExecutorCores = Option(totalExecutorCores) - .orElse(properties.get("spark.cores.max")) + .orElse(sparkProperties.get("spark.cores.max")) .orNull - name = Option(name).orElse(properties.get("spark.app.name")).orNull - jars = Option(jars).orElse(properties.get("spark.jars")).orNull - - // This supports env vars in older versions of Spark - master = Option(master).orElse(env.get("MASTER")).orNull + name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull + jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull // Try to set main class from JAR if no --class argument is given @@ -131,7 +145,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } /** Ensure that required fields exists. Call this only once all defaults are loaded. */ - private def checkRequiredArguments() = { + private def checkRequiredArguments(): Unit = { if (args.length == 0) { printUsageAndExit(-1) } @@ -166,7 +180,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } - override def toString = { + override def toString = { s"""Parsed arguments: | master $master | deployMode $deployMode @@ -174,7 +188,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | executorCores $executorCores | totalExecutorCores $totalExecutorCores | propertiesFile $propertiesFile - | extraSparkProperties $sparkProperties | driverMemory $driverMemory | driverCores $driverCores | driverExtraClassPath $driverExtraClassPath @@ -193,8 +206,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | jars $jars | verbose $verbose | - |Default properties from $propertiesFile: - |${defaultSparkProperties.mkString(" ", "\n ", "\n")} + |Spark properties used, including those specified through + | --conf and those from the properties file $propertiesFile: + |${sparkProperties.mkString(" ", "\n ", "\n")} """.stripMargin } @@ -327,7 +341,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } - private def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { val outStream = SparkSubmit.printStream if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index d25c29113d6da..0e249e51a77d8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -84,11 +84,11 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { {info.id} {info.name} - {startTime} - {endTime} - {duration} + {startTime} + {endTime} + {duration} {info.sparkUser} - {lastUpdated} + {lastUpdated} } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 6ba395be1cc2c..ad7d81747c377 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import akka.actor.ActorRef +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.ApplicationDescription import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala index 2ac21186881fa..9d3d7938c6ccb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.master import java.util.Date +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.DriverDescription import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index 08a99bbe68578..6ff2aa5244847 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -18,10 +18,12 @@ package org.apache.spark.deploy.master import java.io._ - -import akka.serialization.Serialization +import java.nio.ByteBuffer import org.apache.spark.Logging +import org.apache.spark.serializer.Serializer + +import scala.reflect.ClassTag /** * Stores data in a single on-disk directory with one file per application and worker. @@ -32,65 +34,39 @@ import org.apache.spark.Logging */ private[spark] class FileSystemPersistenceEngine( val dir: String, - val serialization: Serialization) + val serialization: Serializer) extends PersistenceEngine with Logging { + val serializer = serialization.newInstance() new File(dir).mkdir() - override def addApplication(app: ApplicationInfo) { - val appFile = new File(dir + File.separator + "app_" + app.id) - serializeIntoFile(appFile, app) - } - - override def removeApplication(app: ApplicationInfo) { - new File(dir + File.separator + "app_" + app.id).delete() - } - - override def addDriver(driver: DriverInfo) { - val driverFile = new File(dir + File.separator + "driver_" + driver.id) - serializeIntoFile(driverFile, driver) + override def persist(name: String, obj: Object): Unit = { + serializeIntoFile(new File(dir + File.separator + name), obj) } - override def removeDriver(driver: DriverInfo) { - new File(dir + File.separator + "driver_" + driver.id).delete() + override def unpersist(name: String): Unit = { + new File(dir + File.separator + name).delete() } - override def addWorker(worker: WorkerInfo) { - val workerFile = new File(dir + File.separator + "worker_" + worker.id) - serializeIntoFile(workerFile, worker) - } - - override def removeWorker(worker: WorkerInfo) { - new File(dir + File.separator + "worker_" + worker.id).delete() - } - - override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - val sortedFiles = new File(dir).listFiles().sortBy(_.getName) - val appFiles = sortedFiles.filter(_.getName.startsWith("app_")) - val apps = appFiles.map(deserializeFromFile[ApplicationInfo]) - val driverFiles = sortedFiles.filter(_.getName.startsWith("driver_")) - val drivers = driverFiles.map(deserializeFromFile[DriverInfo]) - val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_")) - val workers = workerFiles.map(deserializeFromFile[WorkerInfo]) - (apps, drivers, workers) + override def read[T: ClassTag](prefix: String) = { + val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix)) + files.map(deserializeFromFile[T]) } private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - - val out = new FileOutputStream(file) + val out = serializer.serializeStream(new FileOutputStream(file)) try { - out.write(serialized) + out.writeObject(value) } finally { out.close() } + } - def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = { + def deserializeFromFile[T](file: File): T = { val fileData = new Array[Byte](file.length().asInstanceOf[Int]) val dis = new DataInputStream(new FileInputStream(file)) try { @@ -99,8 +75,6 @@ private[spark] class FileSystemPersistenceEngine( dis.close() } - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) - serializer.fromBinary(fileData).asInstanceOf[T] + serializer.deserializeStream(dis).readObject() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala index 4433a2ec29be6..cf77c86d760cf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala @@ -17,30 +17,27 @@ package org.apache.spark.deploy.master -import akka.actor.{Actor, ActorRef} - -import org.apache.spark.deploy.master.MasterMessages.ElectedLeader +import org.apache.spark.annotation.DeveloperApi /** - * A LeaderElectionAgent keeps track of whether the current Master is the leader, meaning it - * is the only Master serving requests. - * In addition to the API provided, the LeaderElectionAgent will use of the following messages - * to inform the Master of leader changes: - * [[org.apache.spark.deploy.master.MasterMessages.ElectedLeader ElectedLeader]] - * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]] + * :: DeveloperApi :: + * + * A LeaderElectionAgent tracks current master and is a common interface for all election Agents. */ -private[spark] trait LeaderElectionAgent extends Actor { - // TODO: LeaderElectionAgent does not necessary to be an Actor anymore, need refactoring. - val masterActor: ActorRef +@DeveloperApi +trait LeaderElectionAgent { + val masterActor: LeaderElectable + def stop() {} // to avoid noops in implementations. } -/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */ -private[spark] class MonarchyLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent { - override def preStart() { - masterActor ! ElectedLeader - } +@DeveloperApi +trait LeaderElectable { + def electedLeader() + def revokedLeadership() +} - override def receive = { - case _ => - } +/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */ +private[spark] class MonarchyLeaderAgent(val masterActor: LeaderElectable) + extends LeaderElectionAgent { + masterActor.electedLeader() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 2f81d472d7b78..021454e25804c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -50,7 +50,7 @@ private[spark] class Master( port: Int, webUiPort: Int, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging { + extends Actor with ActorLogReceive with Logging with LeaderElectable { import context.dispatcher // to use Akka's scheduler.schedule() @@ -61,7 +61,6 @@ private[spark] class Master( val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) - val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") val workers = new HashSet[WorkerInfo] @@ -103,7 +102,7 @@ private[spark] class Master( var persistenceEngine: PersistenceEngine = _ - var leaderElectionAgent: ActorRef = _ + var leaderElectionAgent: LeaderElectionAgent = _ private var recoveryCompletionTask: Cancellable = _ @@ -130,23 +129,24 @@ private[spark] class Master( masterMetricsSystem.start() applicationMetricsSystem.start() - persistenceEngine = RECOVERY_MODE match { + val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match { case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") - new ZooKeeperPersistenceEngine(SerializationExtension(context.system), conf) + val zkFactory = new ZooKeeperRecoveryModeFactory(conf) + (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => - logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) - new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system)) + val fsFactory = new FileSystemRecoveryModeFactory(conf) + (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) + case "CUSTOM" => + val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) + val factory = clazz.getConstructor(conf.getClass) + .newInstance(conf).asInstanceOf[StandaloneRecoveryModeFactory] + (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => - new BlackHolePersistenceEngine() + (new BlackHolePersistenceEngine(), new MonarchyLeaderAgent(this)) } - - leaderElectionAgent = RECOVERY_MODE match { - case "ZOOKEEPER" => - context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl, conf)) - case _ => - context.actorOf(Props(classOf[MonarchyLeaderAgent], self)) - } + persistenceEngine = persistenceEngine_ + leaderElectionAgent = leaderElectionAgent_ } override def preRestart(reason: Throwable, message: Option[Any]) { @@ -165,7 +165,15 @@ private[spark] class Master( masterMetricsSystem.stop() applicationMetricsSystem.stop() persistenceEngine.close() - context.stop(leaderElectionAgent) + leaderElectionAgent.stop() + } + + override def electedLeader() { + self ! ElectedLeader + } + + override def revokedLeadership() { + self ! RevokedLeadership } override def receiveWithLogging = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index e3640ea4f7e64..2e0e1e7036ac8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -17,6 +17,10 @@ package org.apache.spark.deploy.master +import org.apache.spark.annotation.DeveloperApi + +import scala.reflect.ClassTag + /** * Allows Master to persist any state that is necessary in order to recover from a failure. * The following semantics are required: @@ -25,36 +29,70 @@ package org.apache.spark.deploy.master * Given these two requirements, we will have all apps and workers persisted, but * we might not have yet deleted apps or workers that finished (so their liveness must be verified * during recovery). + * + * The implementation of this trait defines how name-object pairs are stored or retrieved. */ -private[spark] trait PersistenceEngine { - def addApplication(app: ApplicationInfo) +@DeveloperApi +trait PersistenceEngine { - def removeApplication(app: ApplicationInfo) + /** + * Defines how the object is serialized and persisted. Implementation will + * depend on the store used. + */ + def persist(name: String, obj: Object) - def addWorker(worker: WorkerInfo) + /** + * Defines how the object referred by its name is removed from the store. + */ + def unpersist(name: String) - def removeWorker(worker: WorkerInfo) + /** + * Gives all objects, matching a prefix. This defines how objects are + * read/deserialized back. + */ + def read[T: ClassTag](prefix: String): Seq[T] - def addDriver(driver: DriverInfo) + final def addApplication(app: ApplicationInfo): Unit = { + persist("app_" + app.id, app) + } - def removeDriver(driver: DriverInfo) + final def removeApplication(app: ApplicationInfo): Unit = { + unpersist("app_" + app.id) + } + + final def addWorker(worker: WorkerInfo): Unit = { + persist("worker_" + worker.id, worker) + } + + final def removeWorker(worker: WorkerInfo): Unit = { + unpersist("worker_" + worker.id) + } + + final def addDriver(driver: DriverInfo): Unit = { + persist("driver_" + driver.id, driver) + } + + final def removeDriver(driver: DriverInfo): Unit = { + unpersist("driver_" + driver.id) + } /** * Returns the persisted data sorted by their respective ids (which implies that they're * sorted by time of creation). */ - def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) + final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { + (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + } def close() {} } private[spark] class BlackHolePersistenceEngine extends PersistenceEngine { - override def addApplication(app: ApplicationInfo) {} - override def removeApplication(app: ApplicationInfo) {} - override def addWorker(worker: WorkerInfo) {} - override def removeWorker(worker: WorkerInfo) {} - override def addDriver(driver: DriverInfo) {} - override def removeDriver(driver: DriverInfo) {} - - override def readPersistedData() = (Nil, Nil, Nil) + + override def persist(name: String, obj: Object): Unit = {} + + override def unpersist(name: String): Unit = {} + + override def read[T: ClassTag](name: String): Seq[T] = Nil + } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala new file mode 100644 index 0000000000000..d9d36c1ed5f9f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.serializer.JavaSerializer + +/** + * ::DeveloperApi:: + * + * Implementation of this class can be plugged in as recovery mode alternative for Spark's + * Standalone mode. + * + */ +@DeveloperApi +abstract class StandaloneRecoveryModeFactory(conf: SparkConf) { + + /** + * PersistenceEngine defines how the persistent data(Information about worker, driver etc..) + * is handled for recovery. + * + */ + def createPersistenceEngine(): PersistenceEngine + + /** + * Create an instance of LeaderAgent that decides who gets elected as master. + */ + def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent +} + +/** + * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual + * recovery is made by restoring from filesystem. + */ +private[spark] class FileSystemRecoveryModeFactory(conf: SparkConf) + extends StandaloneRecoveryModeFactory(conf) with Logging { + val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") + + def createPersistenceEngine() = { + logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) + new FileSystemPersistenceEngine(RECOVERY_DIR, new JavaSerializer(conf)) + } + + def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master) +} + +private[spark] class ZooKeeperRecoveryModeFactory(conf: SparkConf) + extends StandaloneRecoveryModeFactory(conf) { + def createPersistenceEngine() = new ZooKeeperPersistenceEngine(new JavaSerializer(conf), conf) + + def createLeaderElectionAgent(master: LeaderElectable) = + new ZooKeeperLeaderElectionAgent(master, conf) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index d221b0f6cc86b..473ddc23ff0f3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import akka.actor.ActorRef +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils private[spark] class WorkerInfo( diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 285f9b014e291..8eaa0ad948519 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -24,9 +24,8 @@ import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} -private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, - masterUrl: String, conf: SparkConf) - extends LeaderElectionAgent with LeaderLatchListener with Logging { +private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable, + conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging { val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election" @@ -34,30 +33,21 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, private var leaderLatch: LeaderLatch = _ private var status = LeadershipStatus.NOT_LEADER - override def preStart() { + start() + def start() { logInfo("Starting ZooKeeper LeaderElection agent") zk = SparkCuratorUtil.newClient(conf) leaderLatch = new LeaderLatch(zk, WORKING_DIR) leaderLatch.addListener(this) - leaderLatch.start() } - override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) { - logError("LeaderElectionAgent failed...", reason) - super.preRestart(reason, message) - } - - override def postStop() { + override def stop() { leaderLatch.close() zk.close() } - override def receive = { - case _ => - } - override def isLeader() { synchronized { // could have lost leadership by now. @@ -85,10 +75,10 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, def updateLeadershipStatus(isLeader: Boolean) { if (isLeader && status == LeadershipStatus.NOT_LEADER) { status = LeadershipStatus.LEADER - masterActor ! ElectedLeader + masterActor.electedLeader() } else if (!isLeader && status == LeadershipStatus.LEADER) { status = LeadershipStatus.NOT_LEADER - masterActor ! RevokedLeadership + masterActor.revokedLeadership() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 834dfedee52ce..96c2139eb02f0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -19,72 +19,54 @@ package org.apache.spark.deploy.master import scala.collection.JavaConversions._ -import akka.serialization.Serialization import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.serializer.Serializer +import java.nio.ByteBuffer -class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) +import scala.reflect.ClassTag + + +private[spark] class ZooKeeperPersistenceEngine(val serialization: Serializer, conf: SparkConf) extends PersistenceEngine with Logging { val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" val zk: CuratorFramework = SparkCuratorUtil.newClient(conf) - SparkCuratorUtil.mkdir(zk, WORKING_DIR) - - override def addApplication(app: ApplicationInfo) { - serializeIntoFile(WORKING_DIR + "/app_" + app.id, app) - } + val serializer = serialization.newInstance() - override def removeApplication(app: ApplicationInfo) { - zk.delete().forPath(WORKING_DIR + "/app_" + app.id) - } + SparkCuratorUtil.mkdir(zk, WORKING_DIR) - override def addDriver(driver: DriverInfo) { - serializeIntoFile(WORKING_DIR + "/driver_" + driver.id, driver) - } - override def removeDriver(driver: DriverInfo) { - zk.delete().forPath(WORKING_DIR + "/driver_" + driver.id) + override def persist(name: String, obj: Object): Unit = { + serializeIntoFile(WORKING_DIR + "/" + name, obj) } - override def addWorker(worker: WorkerInfo) { - serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker) + override def unpersist(name: String): Unit = { + zk.delete().forPath(WORKING_DIR + "/" + name) } - override def removeWorker(worker: WorkerInfo) { - zk.delete().forPath(WORKING_DIR + "/worker_" + worker.id) + override def read[T: ClassTag](prefix: String) = { + val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix)) + file.map(deserializeFromFile[T]).flatten } override def close() { zk.close() } - override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - val sortedFiles = zk.getChildren().forPath(WORKING_DIR).toList.sorted - val appFiles = sortedFiles.filter(_.startsWith("app_")) - val apps = appFiles.map(deserializeFromFile[ApplicationInfo]).flatten - val driverFiles = sortedFiles.filter(_.startsWith("driver_")) - val drivers = driverFiles.map(deserializeFromFile[DriverInfo]).flatten - val workerFiles = sortedFiles.filter(_.startsWith("worker_")) - val workers = workerFiles.map(deserializeFromFile[WorkerInfo]).flatten - (apps, drivers, workers) - } - private def serializeIntoFile(path: String, value: AnyRef) { - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) + val serialized = serializer.serialize(value) + zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized.array()) } - def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): Option[T] = { + def deserializeFromFile[T](filename: String): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) try { - Some(serializer.fromBinary(fileData).asInstanceOf[T]) + Some(serializer.deserialize(ByteBuffer.wrap(fileData))) } catch { case e: Exception => { logWarning("Exception while reading persisted file, deleting", e) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index aba2e20118d7a..28e9662db5da9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -37,12 +37,12 @@ object CommandUtils extends Logging { * The `env` argument is exposed for testing. */ def buildProcessBuilder( - command: Command, - memory: Int, - sparkHome: String, - substituteArguments: String => String, - classPaths: Seq[String] = Seq[String](), - env: Map[String, String] = sys.env): ProcessBuilder = { + command: Command, + memory: Int, + sparkHome: String, + substituteArguments: String => String, + classPaths: Seq[String] = Seq[String](), + env: Map[String, String] = sys.env): ProcessBuilder = { val localCommand = buildLocalCommand(command, substituteArguments, classPaths, env) val commandSeq = buildCommandSeq(localCommand, memory, sparkHome) val builder = new ProcessBuilder(commandSeq: _*) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala new file mode 100644 index 0000000000000..d044e1d01d429 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.worker + +import org.apache.spark.{Logging, SparkConf, SecurityManager} +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.sasl.SaslRpcHandler +import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler + +/** + * Provides a server from which Executors can read shuffle files (rather than reading directly from + * each other), to provide uninterrupted access to the files in the face of executors being turned + * off or killed. + * + * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. + */ +private[worker] +class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) + extends Logging { + + private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) + private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) + private val useSasl: Boolean = securityManager.isAuthenticationEnabled() + + private val transportConf = SparkTransportConf.fromSparkConf(sparkConf) + private val blockHandler = new ExternalShuffleBlockHandler(transportConf) + private val transportContext: TransportContext = { + val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler + new TransportContext(transportConf, handler) + } + + private var server: TransportServer = _ + + /** Starts the external shuffle service if the user has configured us to. */ + def startIfEnabled() { + if (enabled) { + require(server == null, "Shuffle server already started") + logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + server = transportContext.createServer(port) + } + } + + def stop() { + if (enabled && server != null) { + server.close() + server = null + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c4a8ec2e5e7b0..ca262de832e25 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -111,6 +111,9 @@ private[spark] class Worker( val drivers = new HashMap[String, DriverRunner] val finishedDrivers = new HashMap[String, DriverRunner] + // The shuffle service is not actually started unless configured. + val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + val publicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") if (envVar != null) envVar else host @@ -154,6 +157,7 @@ private[spark] class Worker( logInfo("Spark home: " + sparkHome) createWorkDir() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() registerWithMaster() @@ -186,11 +190,11 @@ private[spark] class Worker( private def retryConnectToMaster() { Utils.tryOrExit { connectionAttemptCount += 1 - logInfo(s"Attempting to connect to master (attempt # $connectionAttemptCount") if (registered) { registrationRetryTimer.foreach(_.cancel()) registrationRetryTimer = None } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { + logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") tryRegisterAllMasters() if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { registrationRetryTimer.foreach(_.cancel()) @@ -419,6 +423,7 @@ private[spark] class Worker( registrationRetryTimer.foreach(_.cancel()) executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) + shuffleService.stop() webUi.stop() metricsSystem.stop() } @@ -441,7 +446,8 @@ private[spark] object Worker extends Logging { cores: Int, memory: Int, masterUrls: Array[String], - workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = { + workDir: String, + workerNumber: Option[Int] = None): (ActorSystem, Int) = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val conf = new SparkConf diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 697154d762d41..3711824a40cfc 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -131,7 +131,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Create a new ActorSystem using driver's Spark properties to run the backend. val driverConf = new SparkConf().setAll(props) val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf)) + SparkEnv.executorActorSystemName, + hostname, port, driverConf, new SecurityManager(driverConf)) // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2889f59e33e84..caf4d76713d49 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.ActorSystem +import akka.actor.{Props, ActorSystem} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -78,7 +78,7 @@ private[spark] class Executor( val executorSource = new ExecutorSource(this, executorId) // Initialize Spark environment (using system properties read above) - conf.set("spark.executor.id", "executor." + executorId) + conf.set("spark.executor.id", executorId) private val env = { if (!isLocal) { val port = conf.getInt("spark.executor.port", 0) @@ -86,12 +86,17 @@ private[spark] class Executor( conf, executorId, slaveHostname, port, isLocal, actorSystem) SparkEnv.set(_env) _env.metricsSystem.registerSource(executorSource) + _env.blockManager.initialize(conf.getAppId) _env } else { SparkEnv.get } } + // Create an actor for receiving RPCs from the driver + private val executorActor = env.actorSystem.actorOf( + Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() @@ -104,6 +109,9 @@ private[spark] class Executor( // to send the result back. private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) + // Limit of bytes for total size of results (default is 1GB) + private val maxResultSize = Utils.getMaxResultSize(conf) + // Start worker thread pool val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") @@ -128,6 +136,7 @@ private[spark] class Executor( def stop() { env.metricsSystem.report() + env.actorSystem.stop(executorActor) isStopped = true threadPool.shutdown() if (!isLocal) { @@ -152,7 +161,7 @@ private[spark] class Executor( } override def run() { - val startTime = System.currentTimeMillis() + val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") @@ -197,7 +206,7 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.executorDeserializeTime = taskStart - startTime + m.executorDeserializeTime = taskStart - deserializeStartTime m.executorRunTime = taskFinish - taskStart m.jvmGCTime = gcTime - startGCTime m.resultSerializationTime = afterSerialization - beforeSerialization @@ -210,25 +219,27 @@ private[spark] class Executor( val resultSize = serializedDirectResult.limit // directSend = sending directly back to the driver - val (serializedResult, directSend) = { - if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { + val serializedResult = { + if (resultSize > maxResultSize) { + logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " + + s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + + s"dropping it.") + ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize)) + } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) - (ser.serialize(new IndirectTaskResult[Any](blockId)), false) + logInfo( + s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") + ser.serialize(new IndirectTaskResult[Any](blockId, resultSize)) } else { - (serializedDirectResult, true) + logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver") + serializedDirectResult } } execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) - if (directSend) { - logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver") - } else { - logInfo( - s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") - } } catch { case ffe: FetchFailedException => { val reason = ffe.toTaskEndReason @@ -252,7 +263,7 @@ private[spark] class Executor( m.executorRunTime = serviceTime m.jvmGCTime = gcTime - startGCTime } - val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics) + val reason = new ExceptionFailure(t, metrics) execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // Don't forcibly exit unless the exception was inherently fatal, to avoid diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala new file mode 100644 index 0000000000000..41925f7e97e84 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import akka.actor.Actor +import org.apache.spark.Logging + +import org.apache.spark.util.{Utils, ActorLogReceive} + +/** + * Driver -> Executor message to trigger a thread dump. + */ +private[spark] case object TriggerThreadDump + +/** + * Actor that runs inside of executors to enable driver -> executor RPC. + */ +private[spark] +class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { + case TriggerThreadDump => + sender ! Utils.getThreadDump() + } + +} diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 57bc2b40cec44..51b5328cb4c8f 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -82,6 +82,12 @@ class TaskMetrics extends Serializable { */ var inputMetrics: Option[InputMetrics] = None + /** + * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much + * data was written are stored here. + */ + var outputMetrics: Option[OutputMetrics] = None + /** * If this task reads from shuffle output, metrics on getting shuffle data will be collected here. * This includes read metrics aggregated over all the task's shuffle dependencies. @@ -157,6 +163,16 @@ object DataReadMethod extends Enumeration with Serializable { val Memory, Disk, Hadoop, Network = Value } +/** + * :: DeveloperApi :: + * Method by which output data was written. + */ +@DeveloperApi +object DataWriteMethod extends Enumeration with Serializable { + type DataWriteMethod = Value + val Hadoop = Value +} + /** * :: DeveloperApi :: * Metrics about reading input data. @@ -169,6 +185,18 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { var bytesRead: Long = 0L } +/** + * :: DeveloperApi :: + * Metrics about writing output data. + */ +@DeveloperApi +case class OutputMetrics(writeMethod: DataWriteMethod.Value) { + /** + * Total bytes written + */ + var bytesWritten: Long = 0L +} + /** * :: DeveloperApi :: * Metrics pertaining to shuffle data read in a given task. diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala new file mode 100644 index 0000000000000..89b29af2000c8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.{BytesWritable, LongWritable} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} + +/** + * Custom Input Format for reading and splitting flat binary files that contain records, + * each of which are a fixed size in bytes. The fixed record size is specified through + * a parameter recordLength in the Hadoop configuration. + */ +private[spark] object FixedLengthBinaryInputFormat { + /** Property name to set in Hadoop JobConfs for record length */ + val RECORD_LENGTH_PROPERTY = "org.apache.spark.input.FixedLengthBinaryInputFormat.recordLength" + + /** Retrieves the record length property from a Hadoop configuration */ + def getRecordLength(context: JobContext): Int = { + context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt + } +} + +private[spark] class FixedLengthBinaryInputFormat + extends FileInputFormat[LongWritable, BytesWritable] { + + private var recordLength = -1 + + /** + * Override of isSplitable to ensure initial computation of the record length + */ + override def isSplitable(context: JobContext, filename: Path): Boolean = { + if (recordLength == -1) { + recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) + } + if (recordLength <= 0) { + println("record length is less than 0, file cannot be split") + false + } else { + true + } + } + + /** + * This input format overrides computeSplitSize() to make sure that each split + * only contains full records. Each InputSplit passed to FixedLengthBinaryRecordReader + * will start at the first byte of a record, and the last byte will the last byte of a record. + */ + override def computeSplitSize(blockSize: Long, minSize: Long, maxSize: Long): Long = { + val defaultSize = super.computeSplitSize(blockSize, minSize, maxSize) + // If the default size is less than the length of a record, make it equal to it + // Otherwise, make sure the split size is as close to possible as the default size, + // but still contains a complete set of records, with the first record + // starting at the first byte in the split and the last record ending with the last byte + if (defaultSize < recordLength) { + recordLength.toLong + } else { + (Math.floor(defaultSize / recordLength) * recordLength).toLong + } + } + + /** + * Create a FixedLengthBinaryRecordReader + */ + override def createRecordReader(split: InputSplit, context: TaskAttemptContext) + : RecordReader[LongWritable, BytesWritable] = { + new FixedLengthBinaryRecordReader + } +} diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala new file mode 100644 index 0000000000000..36a1e5d475f46 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.IOException + +import org.apache.hadoop.fs.FSDataInputStream +import org.apache.hadoop.io.compress.CompressionCodecFactory +import org.apache.hadoop.io.{BytesWritable, LongWritable} +import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.FileSplit + +/** + * FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat. + * It uses the record length set in FixedLengthBinaryInputFormat to + * read one record at a time from the given InputSplit. + * + * Each call to nextKeyValue() updates the LongWritable key and BytesWritable value. + * + * key = record index (Long) + * value = the record itself (BytesWritable) + */ +private[spark] class FixedLengthBinaryRecordReader + extends RecordReader[LongWritable, BytesWritable] { + + private var splitStart: Long = 0L + private var splitEnd: Long = 0L + private var currentPosition: Long = 0L + private var recordLength: Int = 0 + private var fileInputStream: FSDataInputStream = null + private var recordKey: LongWritable = null + private var recordValue: BytesWritable = null + + override def close() { + if (fileInputStream != null) { + fileInputStream.close() + } + } + + override def getCurrentKey: LongWritable = { + recordKey + } + + override def getCurrentValue: BytesWritable = { + recordValue + } + + override def getProgress: Float = { + splitStart match { + case x if x == splitEnd => 0.0.toFloat + case _ => Math.min( + ((currentPosition - splitStart) / (splitEnd - splitStart)).toFloat, 1.0 + ).toFloat + } + } + + override def initialize(inputSplit: InputSplit, context: TaskAttemptContext) { + // the file input + val fileSplit = inputSplit.asInstanceOf[FileSplit] + + // the byte position this fileSplit starts at + splitStart = fileSplit.getStart + + // splitEnd byte marker that the fileSplit ends at + splitEnd = splitStart + fileSplit.getLength + + // the actual file we will be reading from + val file = fileSplit.getPath + // job configuration + val job = context.getConfiguration + // check compression + val codec = new CompressionCodecFactory(job).getCodec(file) + if (codec != null) { + throw new IOException("FixedLengthRecordReader does not support reading compressed files") + } + // get the record length + recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) + // get the filesystem + val fs = file.getFileSystem(job) + // open the File + fileInputStream = fs.open(file) + // seek to the splitStart position + fileInputStream.seek(splitStart) + // set our current position + currentPosition = splitStart + } + + override def nextKeyValue(): Boolean = { + if (recordKey == null) { + recordKey = new LongWritable() + } + // the key is a linear index of the record, given by the + // position the record starts divided by the record length + recordKey.set(currentPosition / recordLength) + // the recordValue to place the bytes into + if (recordValue == null) { + recordValue = new BytesWritable(new Array[Byte](recordLength)) + } + // read a record if the currentPosition is less than the split end + if (currentPosition < splitEnd) { + // setup a buffer to store the record + val buffer = recordValue.getBytes + fileInputStream.readFully(buffer) + // update our current position + currentPosition = currentPosition + recordLength + // return true + return true + } + false + } +} diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala new file mode 100644 index 0000000000000..457472547fcbb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import scala.collection.JavaConversions._ + +import com.google.common.io.ByteStreams +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} + +import org.apache.spark.annotation.Experimental + +/** + * A general format for reading whole files in as streams, byte arrays, + * or other functions to be added + */ +private[spark] abstract class StreamFileInputFormat[T] + extends CombineFileInputFormat[String, T] +{ + override protected def isSplitable(context: JobContext, file: Path): Boolean = false + + /** + * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API + * which is set through setMaxSplitSize + */ + def setMinPartitions(context: JobContext, minPartitions: Int) { + val files = listStatus(context) + val totalLen = files.map { file => + if (file.isDir) 0L else file.getLen + }.sum + + val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong + super.setMaxSplitSize(maxSplitSize) + } + + def createRecordReader(split: InputSplit, taContext: TaskAttemptContext): RecordReader[String, T] + +} + +/** + * An abstract class of [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] + * to reading files out as streams + */ +private[spark] abstract class StreamBasedRecordReader[T]( + split: CombineFileSplit, + context: TaskAttemptContext, + index: Integer) + extends RecordReader[String, T] { + + // True means the current file has been processed, then skip it. + private var processed = false + + private var key = "" + private var value: T = null.asInstanceOf[T] + + override def initialize(split: InputSplit, context: TaskAttemptContext) = {} + override def close() = {} + + override def getProgress = if (processed) 1.0f else 0.0f + + override def getCurrentKey = key + + override def getCurrentValue = value + + override def nextKeyValue = { + if (!processed) { + val fileIn = new PortableDataStream(split, context, index) + value = parseStream(fileIn) + fileIn.close() // if it has not been open yet, close does nothing + key = fileIn.getPath + processed = true + true + } else { + false + } + } + + /** + * Parse the stream (and close it afterwards) and return the value as in type T + * @param inStream the stream to be read in + * @return the data formatted as + */ + def parseStream(inStream: PortableDataStream): T +} + +/** + * Reads the record in directly as a stream for other objects to manipulate and handle + */ +private[spark] class StreamRecordReader( + split: CombineFileSplit, + context: TaskAttemptContext, + index: Integer) + extends StreamBasedRecordReader[PortableDataStream](split, context, index) { + + def parseStream(inStream: PortableDataStream): PortableDataStream = inStream +} + +/** + * The format for the PortableDataStream files + */ +private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] { + override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = { + new CombineFileRecordReader[String, PortableDataStream]( + split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader]) + } +} + +/** + * A class that allows DataStreams to be serialized and moved around by not creating them + * until they need to be read + * @note TaskAttemptContext is not serializable resulting in the confBytes construct + * @note CombineFileSplit is not serializable resulting in the splitBytes construct + */ +@Experimental +class PortableDataStream( + @transient isplit: CombineFileSplit, + @transient context: TaskAttemptContext, + index: Integer) + extends Serializable { + + // transient forces file to be reopened after being serialization + // it is also used for non-serializable classes + + @transient private var fileIn: DataInputStream = null + @transient private var isOpen = false + + private val confBytes = { + val baos = new ByteArrayOutputStream() + context.getConfiguration.write(new DataOutputStream(baos)) + baos.toByteArray + } + + private val splitBytes = { + val baos = new ByteArrayOutputStream() + isplit.write(new DataOutputStream(baos)) + baos.toByteArray + } + + @transient private lazy val split = { + val bais = new ByteArrayInputStream(splitBytes) + val nsplit = new CombineFileSplit() + nsplit.readFields(new DataInputStream(bais)) + nsplit + } + + @transient private lazy val conf = { + val bais = new ByteArrayInputStream(confBytes) + val nconf = new Configuration() + nconf.readFields(new DataInputStream(bais)) + nconf + } + /** + * Calculate the path name independently of opening the file + */ + @transient private lazy val path = { + val pathp = split.getPath(index) + pathp.toString + } + + /** + * Create a new DataInputStream from the split and context + */ + def open(): DataInputStream = { + if (!isOpen) { + val pathp = split.getPath(index) + val fs = pathp.getFileSystem(conf) + fileIn = fs.open(pathp) + isOpen = true + } + fileIn + } + + /** + * Read the file as a byte array + */ + def toArray(): Array[Byte] = { + open() + val innerBuffer = ByteStreams.toByteArray(fileIn) + close() + innerBuffer + } + + /** + * Close the file (if it is currently open) + */ + def close() = { + if (isOpen) { + try { + fileIn.close() + isOpen = false + } catch { + case ioe: java.io.IOException => // do nothing + } + } + } + + def getPath(): String = path +} + diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index 4cb450577796a..183bce3d8d8d3 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -48,9 +48,10 @@ private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[Str } /** - * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API. + * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API, + * which is set through setMaxSplitSize */ - def setMaxSplitSize(context: JobContext, minPartitions: Int) { + def setMinPartitions(context: JobContext, minPartitions: Int) { val files = listStatus(context) val totalLen = files.map { file => if (file.isDir) 0L else file.getLen diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala similarity index 79% rename from core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala rename to core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 0c47afae54c8b..21b782edd2a9e 100644 --- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -15,15 +15,24 @@ * limitations under the License. */ -package org.apache.hadoop.mapred +package org.apache.spark.mapred -private[apache] +import java.lang.reflect.Modifier + +import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext} + +private[spark] trait SparkHadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = { val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext") val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID]) + // In Hadoop 1.0.x, JobContext is an interface, and JobContextImpl is package private. + // Make it accessible if it's not in order to access it. + if (!Modifier.isPublic(ctor.getModifiers)) { + ctor.setAccessible(true) + } ctor.newInstance(conf, jobId).asInstanceOf[JobContext] } @@ -31,6 +40,10 @@ trait SparkHadoopMapRedUtil { val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext") val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID]) + // See above + if (!Modifier.isPublic(ctor.getModifiers)) { + ctor.setAccessible(true) + } ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] } diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala similarity index 96% rename from core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala rename to core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 1fca5729c6092..3340673f91156 100644 --- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -15,13 +15,14 @@ * limitations under the License. */ -package org.apache.hadoop.mapreduce +package org.apache.spark.mapreduce import java.lang.{Boolean => JBoolean, Integer => JInteger} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} -private[apache] +private[spark] trait SparkHadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = { val klass = firstAvailableClass( diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index b083f465334fe..dcbda5a8515dd 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -20,16 +20,16 @@ package org.apache.spark.network import java.io.Closeable import java.nio.ByteBuffer -import scala.concurrent.{Await, Future} +import scala.concurrent.{Promise, Await, Future} import scala.concurrent.duration.Duration import org.apache.spark.Logging import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} -import org.apache.spark.storage.{BlockId, StorageLevel} -import org.apache.spark.util.Utils +import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener} +import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel} private[spark] -abstract class BlockTransferService extends Closeable with Logging { +abstract class BlockTransferService extends ShuffleClient with Closeable with Logging { /** * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch @@ -60,10 +60,11 @@ abstract class BlockTransferService extends Closeable with Logging { * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. */ - def fetchBlocks( - hostName: String, + override def fetchBlocks( + host: String, port: Int, - blockIds: Seq[String], + execId: String, + blockIds: Array[String], listener: BlockFetchingListener): Unit /** @@ -72,6 +73,7 @@ abstract class BlockTransferService extends Closeable with Logging { def uploadBlock( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Future[Unit] @@ -81,43 +83,23 @@ abstract class BlockTransferService extends Closeable with Logging { * * It is also only available after [[init]] is invoked. */ - def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = { + def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = { // A monitor for the thread to wait on. - val lock = new Object - @volatile var result: Either[ManagedBuffer, Throwable] = null - fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener { - override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { - lock.synchronized { - result = Right(exception) - lock.notify() + val result = Promise[ManagedBuffer]() + fetchBlocks(host, port, execId, Array(blockId), + new BlockFetchingListener { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { + result.failure(exception) } - } - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - lock.synchronized { + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { val ret = ByteBuffer.allocate(data.size.toInt) ret.put(data.nioByteBuffer()) ret.flip() - result = Left(new NioManagedBuffer(ret)) - lock.notify() + result.success(new NioManagedBuffer(ret)) } - } - }) + }) - // Sleep until result is no longer null - lock.synchronized { - while (result == null) { - try { - lock.wait() - } catch { - case e: InterruptedException => - } - } - } - - result match { - case Left(data) => data - case Right(e) => throw e - } + Await.result(result.future, Duration.Inf) } /** @@ -129,9 +111,10 @@ abstract class BlockTransferService extends Closeable with Logging { def uploadBlockSync( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Unit = { - Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf) + Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala deleted file mode 100644 index 8c5ffd8da6bbb..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.nio.ByteBuffer -import java.util - -import org.apache.spark.{SparkConf, Logging} -import org.apache.spark.network.BlockFetchingListener -import org.apache.spark.network.netty.NettyMessages._ -import org.apache.spark.serializer.{JavaSerializer, Serializer} -import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, TransportClient} -import org.apache.spark.storage.BlockId -import org.apache.spark.util.Utils - -/** - * Responsible for holding the state for a request for a single set of blocks. This assumes that - * the chunks will be returned in the same order as requested, and that there will be exactly - * one chunk per block. - * - * Upon receipt of any block, the listener will be called back. Upon failure part way through, - * the listener will receive a failure callback for each outstanding block. - */ -class NettyBlockFetcher( - serializer: Serializer, - client: TransportClient, - blockIds: Seq[String], - listener: BlockFetchingListener) - extends Logging { - - require(blockIds.nonEmpty) - - private val ser = serializer.newInstance() - - private var streamHandle: ShuffleStreamHandle = _ - - private val chunkCallback = new ChunkReceivedCallback { - // On receipt of a chunk, pass it upwards as a block. - def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions { - listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer) - } - - // On receipt of a failure, fail every block from chunkIndex onwards. - def onFailure(chunkIndex: Int, e: Throwable): Unit = { - blockIds.drop(chunkIndex).foreach { blockId => - listener.onBlockFetchFailure(blockId, e); - } - } - } - - /** Begins the fetching process, calling the listener with every block fetched. */ - def start(): Unit = { - // Send the RPC to open the given set of blocks. This will return a ShuffleStreamHandle. - client.sendRpc(ser.serialize(OpenBlocks(blockIds.map(BlockId.apply))).array(), - new RpcResponseCallback { - override def onSuccess(response: Array[Byte]): Unit = { - try { - streamHandle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(response)) - logTrace(s"Successfully opened block set: $streamHandle! Preparing to fetch chunks.") - - // Immediately request all chunks -- we expect that the total size of the request is - // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. - for (i <- 0 until streamHandle.numChunks) { - client.fetchChunk(streamHandle.streamId, i, chunkCallback) - } - } catch { - case e: Exception => - logError("Failed while starting block fetches", e) - blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e))) - } - } - - override def onFailure(e: Throwable): Unit = { - logError("Failed while starting block fetches", e) - blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e))) - } - }) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 02c657e1d61b5..b089da8596e2b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -19,58 +19,55 @@ package org.apache.spark.network.netty import java.nio.ByteBuffer +import scala.collection.JavaConversions._ + import org.apache.spark.Logging import org.apache.spark.network.BlockDataManager +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock} import org.apache.spark.serializer.Serializer -import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} -import org.apache.spark.network.client.{TransportClient, RpcResponseCallback} -import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler} -import org.apache.spark.storage.{StorageLevel, BlockId} - -import scala.collection.JavaConversions._ - -object NettyMessages { - - /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */ - case class OpenBlocks(blockIds: Seq[BlockId]) - - /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ - case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel) - - /** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */ - case class ShuffleStreamHandle(streamId: Long, numChunks: Int) -} +import org.apache.spark.storage.{BlockId, StorageLevel} /** * Serves requests to open blocks by simply registering one chunk per block requested. + * Handles opening and uploading arbitrary BlockManager blocks. + * + * Opened blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk + * is equivalent to one Spark-level shuffle block. */ class NettyBlockRpcServer( serializer: Serializer, - streamManager: DefaultStreamManager, blockManager: BlockDataManager) extends RpcHandler with Logging { - import NettyMessages._ + private val streamManager = new OneForOneStreamManager() override def receive( client: TransportClient, messageBytes: Array[Byte], responseContext: RpcResponseCallback): Unit = { - val ser = serializer.newInstance() - val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes)) + val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes) logTrace(s"Received request: $message") message match { - case OpenBlocks(blockIds) => - val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData) + case openBlocks: OpenBlocks => + val blocks: Seq[ManagedBuffer] = + openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) val streamId = streamManager.registerStream(blocks.iterator) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess( - ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array()) + responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) - case UploadBlock(blockId, blockData, level) => - blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level) + case uploadBlock: UploadBlock => + // StorageLevel is serialized as bytes using our JavaSerializer. + val level: StorageLevel = + serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) + val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) + blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level) responseContext.onSuccess(new Array[Byte](0)) } } + + override def getStreamManager(): StreamManager = streamManager } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 38a3e945155e8..f8a7f640689a2 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,15 +17,17 @@ package org.apache.spark.network.netty -import scala.concurrent.{Promise, Future} +import scala.collection.JavaConversions._ +import scala.concurrent.{Future, Promise} -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{RpcResponseCallback, TransportClient, TransportClientFactory} -import org.apache.spark.network.netty.NettyMessages.UploadBlock +import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} +import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.util.{ConfigProvider, TransportConf} +import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} +import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -33,34 +35,59 @@ import org.apache.spark.util.Utils /** * A BlockTransferService that uses Netty to fetch a set of blocks at at time. */ -class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { - // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. - val serializer = new JavaSerializer(conf) +class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager) + extends BlockTransferService { - // Create a TransportConfig using SparkConf. - private[this] val transportConf = new TransportConf( - new ConfigProvider { override def get(name: String) = conf.get(name) }) + // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. + private val serializer = new JavaSerializer(conf) + private val authEnabled = securityManager.isAuthenticationEnabled() + private val transportConf = SparkTransportConf.fromSparkConf(conf) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ private[this] var clientFactory: TransportClientFactory = _ + private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { - val streamManager = new DefaultStreamManager - val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager) - transportContext = new TransportContext(transportConf, streamManager, rpcHandler) - clientFactory = transportContext.createClientFactory() + val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = { + val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + if (!authEnabled) { + (nettyRpcHandler, None) + } else { + (new SaslRpcHandler(nettyRpcHandler, securityManager), + Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager))) + } + } + transportContext = new TransportContext(transportConf, rpcHandler) + clientFactory = transportContext.createClientFactory(bootstrap.toList) server = transportContext.createServer() + appId = conf.getAppId + logInfo("Server created on " + server.getPort) } override def fetchBlocks( - hostname: String, + host: String, port: Int, - blockIds: Seq[String], + execId: String, + blockIds: Array[String], listener: BlockFetchingListener): Unit = { + logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { - val client = clientFactory.createClient(hostname, port) - new NettyBlockFetcher(serializer, client, blockIds, listener).start() + val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { + override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { + val client = clientFactory.createClient(host, port) + new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start() + } + } + + val maxRetries = transportConf.maxIORetries() + if (maxRetries > 0) { + // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's + // a bug in this code. We should remove the if statement once we're sure of the stability. + new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start() + } else { + blockFetchStarter.createAndStart(blockIds, listener) + } } catch { case e: Exception => logError("Exception while beginning fetchBlocks", e) @@ -75,12 +102,17 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { override def uploadBlock( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Future[Unit] = { val result = Promise[Unit]() val client = clientFactory.createClient(hostname, port) + // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded + // using our binary protocol. + val levelBytes = serializer.newInstance().serialize(level).array() + // Convert or copy nio buffer into array in order to serialize it. val nioBuffer = blockData.nioByteBuffer() val array = if (nioBuffer.hasArray) { @@ -91,8 +123,7 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { data } - val ser = serializer.newInstance() - client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(), + client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray, new RpcResponseCallback { override def onSuccess(response: Array[Byte]): Unit = { logTrace(s"Successfully uploaded block $blockId") @@ -107,5 +138,8 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { result.future } - override def close(): Unit = server.close() + override def close(): Unit = { + server.close() + clientFactory.close() + } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala new file mode 100644 index 0000000000000..9fa4fa77b8817 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import org.apache.spark.SparkConf +import org.apache.spark.network.util.{TransportConf, ConfigProvider} + +/** + * Utility for creating a [[TransportConf]] from a [[SparkConf]]. + */ +object SparkTransportConf { + def fromSparkConf(conf: SparkConf): TransportConf = { + new TransportConf(new ConfigProvider { + override def get(name: String): String = conf.get(name) + }) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 4f6f5e235811d..c2d9578be7ebb 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -23,12 +23,13 @@ import java.nio.channels._ import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList -import org.apache.spark._ - import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal +import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} + private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId, diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 8408b75bb4d65..f198aa8564a54 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -34,6 +34,7 @@ import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} import org.apache.spark.util.Utils import scala.util.Try @@ -600,7 +601,7 @@ private[nio] class ConnectionManager( } else { var replyToken : Array[Byte] = null try { - replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken) + replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken) if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId @@ -634,7 +635,7 @@ private[nio] class ConnectionManager( connection.synchronized { if (connection.sparkSaslServer == null) { logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(securityManager) + connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager) } } replyToken = connection.sparkSaslServer.response(securityMsg.getToken) @@ -778,7 +779,7 @@ private[nio] class ConnectionManager( if (!conn.isSaslComplete()) { conn.synchronized { if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(securityManager) + conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager) var firstResponse: Array[Byte] = null try { firstResponse = conn.sparkSaslClient.firstToken() diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 11793ea92adb1..b2aec160635c7 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} @@ -79,13 +80,14 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa } override def fetchBlocks( - hostName: String, + host: String, port: Int, - blockIds: Seq[String], + execId: String, + blockIds: Array[String], listener: BlockFetchingListener): Unit = { checkInit() - val cmId = new ConnectionManagerId(hostName, port) + val cmId = new ConnectionManagerId(host, port) val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => BlockMessage.fromGetBlock(GetBlock(BlockId(blockId))) }) @@ -135,6 +137,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa override def uploadBlock( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala new file mode 100644 index 0000000000000..6e66ddbdef788 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.hadoop.conf.{ Configurable, Configuration } +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ +import org.apache.spark.input.StreamFileInputFormat +import org.apache.spark.{ Partition, SparkContext } + +private[spark] class BinaryFileRDD[T]( + sc: SparkContext, + inputFormatClass: Class[_ <: StreamFileInputFormat[T]], + keyClass: Class[String], + valueClass: Class[T], + @transient conf: Configuration, + minPartitions: Int) + extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) { + + override def getPartitions: Array[Partition] = { + val inputFormat = inputFormatClass.newInstance + inputFormat match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val jobContext = newJobContext(conf, jobId) + inputFormat.setMinPartitions(jobContext, minPartitions) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 2673ec22509e9..fffa1911f5bc2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -84,5 +84,9 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds "Attempted to use %s after its blocks have been removed!".format(toString)) } } + + protected def getBlockIdLocations(): Map[BlockId, Seq[String]] = { + locations_ + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 946fb5616d3ec..a157e36e2286e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -211,20 +211,11 @@ class HadoopRDD[K, V]( val split = theSplit.asInstanceOf[HadoopPartition] logInfo("Input split: " + split.inputSplit) - var reader: RecordReader[K, V] = null val jobConf = getJobConf() - val inputFormat = getInputFormat(jobConf) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), - context.stageId, theSplit.index, context.attemptId.toInt, jobConf) - reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) - - // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener{ context => closeIfNeeded() } - val key: K = reader.createKey() - val value: V = reader.createValue() val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - // Find a function that will return the FileSystem bytes read by this thread. + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) { SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf) @@ -234,6 +225,18 @@ class HadoopRDD[K, V]( if (bytesReadCallback.isDefined) { context.taskMetrics.inputMetrics = Some(inputMetrics) } + + var reader: RecordReader[K, V] = null + val inputFormat = getInputFormat(jobConf) + HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), + context.stageId, theSplit.index, context.attemptId.toInt, jobConf) + reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + + // Register an on-task-completion callback to close the input stream. + context.addTaskCompletionListener{ context => closeIfNeeded() } + val key: K = reader.createKey() + val value: V = reader.createValue() + var recordsSinceMetricsUpdate = 0 override def getNext() = { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 324563248793c..e55d03d391e03 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -35,6 +35,7 @@ import org.apache.spark.Partition import org.apache.spark.SerializableWritable import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.util.Utils import org.apache.spark.deploy.SparkHadoopUtil @@ -107,20 +108,10 @@ class NewHadoopRDD[K, V]( val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) val conf = confBroadcast.value.value - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) - val format = inputFormatClass.newInstance - format match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - // Find a function that will return the FileSystem bytes read by this thread. + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf) @@ -131,6 +122,18 @@ class NewHadoopRDD[K, V]( context.taskMetrics.inputMetrics = Some(inputMetrics) } + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val format = inputFormatClass.newInstance + format match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) var havePair = false @@ -263,7 +266,7 @@ private[spark] class WholeTextFileRDD( case _ => } val jobContext = newJobContext(conf, jobId) - inputFormat.setMaxSplitSize(jobContext, minPartitions) + inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index da89f634abaea..8c2c959e73bb6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -28,18 +28,20 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, -RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil} +RecordWriter => NewRecordWriter} import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils @@ -961,30 +963,40 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { + val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, attemptNumber) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) + val hadoopContext = newTaskAttemptContext(config, attemptId) val format = outfmt.newInstance format match { - case c: Configurable => c.setConf(wrappedConf.value) + case c: Configurable => c.setConf(config) case _ => () } val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) + + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] try { + var recordsWritten = 0L while (iter.hasNext) { val pair = iter.next() writer.write(pair._1, pair._2) + + // Update bytes written metric every few records + maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) + recordsWritten += 1 } } finally { writer.close(hadoopContext) } committer.commitTask(hadoopContext) + bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } 1 } : Int @@ -1005,6 +1017,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf) { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf + val wrappedConf = new SerializableWritable(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass @@ -1032,27 +1045,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.preSetup() val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { + val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + writer.setup(context.stageId, context.partitionId, attemptNumber) writer.open() try { + var recordsWritten = 0L while (iter.hasNext) { val record = iter.next() writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) + + // Update bytes written metric every few records + maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) + recordsWritten += 1 } } finally { writer.close() } writer.commit() + bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } } self.context.runJob(self, writeToFile) writer.commitJob() } + private def initHadoopOutputMetrics(context: TaskContext, config: Configuration) + : (OutputMetrics, Option[() => Long]) = { + val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir")) + .map(new Path(_)) + .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config)) + val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) + if (bytesWrittenCallback.isDefined) { + context.taskMetrics.outputMetrics = Some(outputMetrics) + } + (outputMetrics, bytesWrittenCallback) + } + + private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long], + outputMetrics: OutputMetrics, recordsWritten: Long): Unit = { + if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0 + && bytesWrittenCallback.isDefined) { + bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } + } + } + /** * Return an RDD with the keys of each tuple. */ @@ -1069,3 +1111,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord) } + +private[spark] object PairRDDFunctions { + val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 +} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index b7f125d01dfaf..716f2dd17733b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -43,7 +43,8 @@ import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite} import org.apache.spark.util.collection.OpenHashMap -import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} +import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler, + SamplingUtils} /** * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, @@ -375,7 +376,8 @@ abstract class RDD[T: ClassTag]( val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed) + new PartitionwiseSampledRDD[T, T]( + this, new BernoulliCellSampler[T](x(0), x(1)), true, seed) }.toArray } @@ -1094,7 +1096,7 @@ abstract class RDD[T: ClassTag]( } /** - * Returns the top K (largest) elements from this RDD as defined by the specified + * Returns the top k (largest) elements from this RDD as defined by the specified * implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example: * {{{ * sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1) @@ -1104,14 +1106,14 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @param ord the implicit ordering for T * @return an array of top elements */ def top(num: Int)(implicit ord: Ordering[T]): Array[T] = takeOrdered(num)(ord.reverse) /** - * Returns the first K (smallest) elements from this RDD as defined by the specified + * Returns the first k (smallest) elements from this RDD as defined by the specified * implicit Ordering[T] and maintains the ordering. This does the opposite of [[top]]. * For example: * {{{ @@ -1122,7 +1124,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * - * @param num the number of top elements to return + * @param num k, the number of elements to return * @param ord the implicit ordering for T * @return an array of top elements */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f81fa6d8089fc..22449517d100f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -124,6 +124,9 @@ class DAGScheduler( /** If enabled, we may run certain actions like take() and first() locally. */ private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) + /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ + private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) + private def initializeEventProcessActor() { // blocking the thread until supervisor is started, which ensures eventProcessActor is // not null before any job is submitted @@ -1050,7 +1053,7 @@ class DAGScheduler( logInfo("Resubmitted " + task + ", so marking it as still running") stage.pendingTasks += task - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleToMapStage(shuffleId) @@ -1060,11 +1063,13 @@ class DAGScheduler( if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some("Fetch failure")) + markStageAsFinished(failedStage, Some(failureMessage)) runningStages -= failedStage } - if (failedStages.isEmpty && eventProcessActor != null) { + if (disallowStageRetryForTest) { + abortStage(failedStage, "Fetch failure will not retry stage due to testing config") + } else if (failedStages.isEmpty && eventProcessActor != null) { // Don't schedule an event to resubmit failed stages if failed isn't empty, because // in that case the event will already have been scheduled. eventProcessActor may be // null during unit tests. @@ -1086,10 +1091,10 @@ class DAGScheduler( // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, Some(task.epoch)) + handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) } - case ExceptionFailure(className, description, stackTrace, metrics) => + case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures case TaskResultLost => @@ -1106,25 +1111,35 @@ class DAGScheduler( * Responds to an executor being lost. This is called inside the event loop, so it assumes it can * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. * + * We will also assume that we've lost all shuffle blocks associated with the executor if the + * executor serves its own blocks (i.e., we're not using external shuffle) OR a FetchFailed + * occurred, in which case we presume all shuffle data related to this executor to be lost. + * * Optionally the epoch during which the failure was caught can be passed to avoid allowing * stray fetch failures from possibly retriggering the detection of a node as lost. */ - private[scheduler] def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) { + private[scheduler] def handleExecutorLost( + execId: String, + fetchFailed: Boolean, + maybeEpoch: Option[Long] = None) { val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { failedEpoch(execId) = currentEpoch logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch)) blockManagerMaster.removeExecutor(execId) - // TODO: This will be really slow if we keep accumulating shuffle map stages - for ((shuffleId, stage) <- shuffleToMapStage) { - stage.removeOutputsOnExecutor(execId) - val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray - mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) - } - if (shuffleToMapStage.isEmpty) { - mapOutputTracker.incrementEpoch() + + if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) { + // TODO: This will be really slow if we keep accumulating shuffle map stages + for ((shuffleId, stage) <- shuffleToMapStage) { + stage.removeOutputsOnExecutor(execId) + val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray + mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) + } + if (shuffleToMapStage.isEmpty) { + mapOutputTracker.incrementEpoch() + } + clearCacheLocs() } - clearCacheLocs() } else { logDebug("Additional executor lost message for " + execId + "(epoch " + currentEpoch + ")") @@ -1382,7 +1397,7 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule dagScheduler.handleExecutorAdded(execId, host) case ExecutorLost(execId) => - dagScheduler.handleExecutorLost(execId) + dagScheduler.handleExecutorLost(execId, fetchFailed = false) case BeginEvent(task, taskInfo) => dagScheduler.handleBeginEvent(task, taskInfo) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 100c9ba9b7809..597dbc884913c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -142,7 +142,7 @@ private[spark] object EventLoggingListener extends Logging { val SPARK_VERSION_PREFIX = "SPARK_VERSION_" val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_" val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" - val LOG_FILE_PERMISSIONS = FsPermission.createImmutable(Integer.parseInt("770", 8).toShort) + val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) // A cache for compression codecs to avoid creating the same codec many times private val codecMap = new mutable.HashMap[String, CompressionCodec] diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 54904bffdf10b..3bb54855bae44 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -158,6 +158,11 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " INPUT_BYTES=" + metrics.bytesRead case None => "" } + val outputMetrics = taskMetrics.outputMetrics match { + case Some(metrics) => + " OUTPUT_BYTES=" + metrics.bytesWritten + case None => "" + } val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match { case Some(metrics) => " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + @@ -173,7 +178,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime case None => "" } - stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + + stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + outputMetrics + shuffleReadMetrics + writeMetrics) } @@ -215,7 +220,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + " STAGE_ID=" + taskEnd.stageId stageLogInfo(taskEnd.stageId, taskStatus) - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) => taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + mapId + " REDUCE_ID=" + reduceId diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 071568cdfb429..cc13f57a49b89 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -102,6 +102,11 @@ private[spark] class Stage( } } + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. + */ def removeOutputsOnExecutor(execId: String) { var becameUnavailable = false for (partition <- 0 until numPartitions) { @@ -131,4 +136,9 @@ private[spark] class Stage( override def toString = "Stage " + id override def hashCode(): Int = id + + override def equals(other: Any): Boolean = other match { + case stage: Stage => stage != null && stage.id == id + case _ => false + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 11c19eeb6e42c..1f114a0207f7b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -31,8 +31,8 @@ import org.apache.spark.util.Utils private[spark] sealed trait TaskResult[T] /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */ -private[spark] -case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable +private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int) + extends TaskResult[T] with Serializable /** A TaskResult that contains the task's return value and accumulator updates. */ private[spark] diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 3f345ceeaaf7a..819b51e12ad8c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -47,9 +47,18 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { try { - val result = serializer.get().deserialize[TaskResult[_]](serializedData) match { - case directResult: DirectTaskResult[_] => directResult - case IndirectTaskResult(blockId) => + val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match { + case directResult: DirectTaskResult[_] => + if (!taskSetManager.canFetchMoreResults(serializedData.limit())) { + return + } + (directResult, serializedData.limit()) + case IndirectTaskResult(blockId, size) => + if (!taskSetManager.canFetchMoreResults(size)) { + // dropped by executor if size is larger than maxResultSize + sparkEnv.blockManager.master.removeBlock(blockId) + return + } logDebug("Fetching indirect task result for TID %s".format(tid)) scheduler.handleTaskGettingResult(taskSetManager, tid) val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId) @@ -64,9 +73,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( serializedTaskResult.get) sparkEnv.blockManager.master.removeBlock(blockId) - deserializedResult + (deserializedResult, size) } - result.metrics.resultSize = serializedData.limit() + + result.metrics.resultSize = size scheduler.handleSuccessfulTask(taskSetManager, tid, result) } catch { case cnf: ClassNotFoundException => @@ -93,7 +103,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul } } catch { case cnd: ClassNotFoundException => - // Log an error but keep going here -- the task failed, so not catastropic if we can't + // Log an error but keep going here -- the task failed, so not catastrophic if we can't // deserialize the reason. val loader = Utils.getContextOrSparkClassLoader logError( diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index a129a434c9a1a..f095915352b17 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -23,7 +23,7 @@ import org.apache.spark.storage.BlockManagerId /** * Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl. - * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks + * This interface allows plugging in different task schedulers. Each TaskScheduler schedules tasks * for a single SparkContext. These schedulers get sets of tasks submitted to them from the * DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running * them, retrying if there are failures, and mitigating stragglers. They return events to the @@ -41,7 +41,7 @@ private[spark] trait TaskScheduler { // Invoked after system has successfully initialized (typically in spark context). // Yarn uses this to bootstrap allocation of resources based on preferred locations, - // wait for slave registerations, etc. + // wait for slave registrations, etc. def postStartHook() { } // Disconnect from the cluster. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a6c23fc85a1b0..d8fb640350343 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -23,13 +23,12 @@ import java.util.Arrays import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min +import scala.math.{min, max} import org.apache.spark._ -import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.{Clock, SystemClock} +import org.apache.spark.TaskState.TaskState +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of @@ -68,6 +67,9 @@ private[spark] class TaskSetManager( val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5) + // Limit of bytes for total size of results (default is 1GB) + val maxResultSize = Utils.getMaxResultSize(conf) + // Serializer for closures and tasks. val env = SparkEnv.get val ser = env.closureSerializer.newInstance() @@ -89,6 +91,8 @@ private[spark] class TaskSetManager( var stageId = taskSet.stageId var name = "TaskSet_" + taskSet.stageId.toString var parent: Pool = null + var totalResultSize = 0L + var calculatedTasks = 0 val runningTasksSet = new HashSet[Long] override def runningTasks = runningTasksSet.size @@ -515,12 +519,33 @@ private[spark] class TaskSetManager( index } + /** + * Marks the task as getting result and notifies the DAG Scheduler + */ def handleTaskGettingResult(tid: Long) = { val info = taskInfos(tid) info.markGettingResult() sched.dagScheduler.taskGettingResult(info) } + /** + * Check whether has enough quota to fetch the result with `size` bytes + */ + def canFetchMoreResults(size: Long): Boolean = synchronized { + totalResultSize += size + calculatedTasks += 1 + if (maxResultSize > 0 && totalResultSize > maxResultSize) { + val msg = s"Total size of serialized results of ${calculatedTasks} tasks " + + s"(${Utils.bytesToString(totalResultSize)}) is bigger than maxResultSize " + + s"(${Utils.bytesToString(maxResultSize)})" + logError(msg) + abort(msg) + false + } else { + true + } + } + /** * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ @@ -687,10 +712,11 @@ private[spark] class TaskSetManager( addPendingTask(index, readding=true) } - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage. + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage, + // and we are not using an external shuffle server which could serve the shuffle outputs. // The reason is the next stage wouldn't be able to fetch the data from this dead executor // so we would need to rerun these tasks on other executors. - if (tasks(0).isInstanceOf[ShuffleMapTask]) { + if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index if (successful(index)) { @@ -706,7 +732,7 @@ private[spark] class TaskSetManager( } // Also re-enqueue any tasks that were running on the node for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure) + handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId)) } // recalculate valid locality levels and waits when executor is lost recomputeLocality() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d8c0e2f66df01..5289661eb896b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -93,7 +93,7 @@ private[spark] class CoarseMesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { { val ret = driver.run() @@ -242,8 +242,7 @@ private[spark] class CoarseMesosSchedulerBackend( for (r <- res if r.getName == name) { return r.getScalar.getValue } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) + 0 } /** Build a Mesos resource protobuf object */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 8e2faff90f9b2..c5f3493477bc5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -72,7 +72,7 @@ private[spark] class MesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { val ret = driver.run() @@ -278,8 +278,7 @@ private[spark] class MesosSchedulerBackend( for (r <- res if r.getName == name) { return r.getScalar.getValue } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) + 0 } /** Turn a Spark TaskDescription into a Mesos task */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 58b78f041cd85..c0264836de738 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import akka.actor.{Actor, ActorRef, Props} -import org.apache.spark.{Logging, SparkEnv, TaskState} +import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} @@ -47,7 +47,7 @@ private[spark] class LocalActor( private var freeCores = totalCores - private val localExecutorId = "localhost" + private val localExecutorId = SparkContext.DRIVER_IDENTIFIER private val localExecutorHostname = "localhost" val executor = new Executor( diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 71c08e9d5a8c3..be184464e0ae9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -19,6 +19,7 @@ package org.apache.spark.shuffle import org.apache.spark.storage.BlockManagerId import org.apache.spark.{FetchFailed, TaskEndReason} +import org.apache.spark.util.Utils /** * Failed to fetch a shuffle block. The executor catches this exception and propagates it @@ -30,13 +31,22 @@ private[spark] class FetchFailedException( bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, - reduceId: Int) - extends Exception { + reduceId: Int, + message: String, + cause: Throwable = null) + extends Exception(message, cause) { - override def getMessage: String = - "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId) + def this( + bmAddress: BlockManagerId, + shuffleId: Int, + mapId: Int, + reduceId: Int, + cause: Throwable) { + this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) + } - def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId) + def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, + Utils.exceptionString(this)) } /** @@ -46,7 +56,4 @@ private[spark] class MetadataFetchFailedException( shuffleId: Int, reduceId: Int, message: String) - extends FetchFailedException(null, shuffleId, -1, reduceId) { - - override def getMessage: String = message -} + extends FetchFailedException(null, shuffleId, -1, reduceId, message) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 1fb5b2c4546bd..7de2f9cbb2866 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -27,6 +27,7 @@ import scala.collection.JavaConversions._ import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup import org.apache.spark.storage._ @@ -62,11 +63,14 @@ private[spark] trait ShuffleWriterGroup { * each block stored in each file. In order to find the location of a shuffle block, we search the * files within a ShuffleFileGroups associated with the block's reducer. */ - +// Note: Changes to the format in this file should be kept in sync with +// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData(). private[spark] class FileShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager with Logging { + private val transportConf = SparkTransportConf.fromSparkConf(conf) + private lazy val blockManager = SparkEnv.get.blockManager // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. @@ -181,13 +185,14 @@ class FileShuffleBlockManager(conf: SparkConf) val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) if (segmentOpt.isDefined) { val segment = segmentOpt.get - return new FileSegmentManagedBuffer(segment.file, segment.offset, segment.length) + return new FileSegmentManagedBuffer( + transportConf, segment.file, segment.offset, segment.length) } } throw new IllegalStateException("Failed to find shuffle block: " + blockId) } else { val file = blockManager.diskBlockManager.getFile(blockId) - new FileSegmentManagedBuffer(file, 0, file.length) + new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index e9805c9c134b5..b292587d37028 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -22,8 +22,9 @@ import java.nio.ByteBuffer import com.google.common.io.ByteStreams -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.storage._ /** @@ -35,11 +36,15 @@ import org.apache.spark.storage._ * as the filename postfix for data file, and ".index" as the filename postfix for index file. * */ +// Note: Changes to the format in this file should be kept in sync with +// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData(). private[spark] -class IndexShuffleBlockManager extends ShuffleBlockManager { +class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { private lazy val blockManager = SparkEnv.get.blockManager + private val transportConf = SparkTransportConf.fromSparkConf(conf) + /** * Mapping to a single shuffleBlockId with reduce ID 0. * */ @@ -107,6 +112,7 @@ class IndexShuffleBlockManager extends ShuffleBlockManager { val offset = in.readLong() val nextOffset = in.readLong() new FileSegmentManagedBuffer( + transportConf, getDataFile(blockId.shuffleId, blockId.mapId), offset, nextOffset - offset) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 6cf9305977a3c..e3e7434df45b0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -19,6 +19,7 @@ package org.apache.spark.shuffle.hash import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import scala.util.{Failure, Success, Try} import org.apache.spark._ import org.apache.spark.serializer.Serializer @@ -52,21 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = { + def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { - case Some(block) => { + case Success(block) => { block.asInstanceOf[Iterator[T]] } - case None => { + case Failure(e) => { blockId match { case ShuffleBlockId(shufId, mapId, _) => val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId) + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) case _ => throw new SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block") + "Failed to get block " + blockId + ", which is not a shuffle block", e) } } } @@ -74,7 +75,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { val blockFetcherItr = new ShuffleBlockFetcherIterator( context, - SparkEnv.get.blockTransferService, + SparkEnv.get.blockManager.shuffleClient, blockManager, blocksByAddress, serializer, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 746ed33b54c00..183a30373b28c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -107,7 +107,7 @@ private[spark] class HashShuffleWriter[K, V]( writer.commitAndClose() writer.fileSegment().length } - MapStatus(blockManager.blockManagerId, sizes) + MapStatus(blockManager.shuffleServerId, sizes) } private def revertWrites(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index b727438ae7e47..bda30a56d808e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -25,7 +25,7 @@ import org.apache.spark.shuffle.hash.HashShuffleReader private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { - private val indexShuffleBlockManager = new IndexShuffleBlockManager() + private val indexShuffleBlockManager = new IndexShuffleBlockManager(conf) private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 927481b72cf4f..d75f9d7311fad 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,7 +70,7 @@ private[spark] class SortShuffleWriter[K, V, C]( val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) - mapStatus = MapStatus(blockManager.blockManagerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 8df5ec6bde184..1f012941c85ab 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -53,6 +53,8 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { def name = "rdd_" + rddId + "_" + splitIndex } +// Format of the shuffle block ids (including data and index) should be kept in sync with +// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData(). @DeveloperApi case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 58510d7232436..39434f473a9d8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -21,9 +21,9 @@ import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ -import scala.concurrent.{Await, Future} import scala.util.Random import akka.actor.{ActorSystem, Props} @@ -34,8 +34,13 @@ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService} +import org.apache.spark.network.shuffle.ExternalShuffleClient +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo +import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util._ private[spark] sealed trait BlockValues @@ -52,6 +57,12 @@ private[spark] class BlockResult( inputMetrics.bytesRead = bytes } +/** + * Manager running on every node (driver and executors) which provides interfaces for putting and + * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). + * + * Note that #initialize() must be called before the BlockManager is usable. + */ private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, @@ -61,11 +72,10 @@ private[spark] class BlockManager( val conf: SparkConf, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService) + blockTransferService: BlockTransferService, + securityManager: SecurityManager) extends BlockDataManager with Logging { - blockTransferService.init(this) - val diskBlockManager = new DiskBlockManager(this, conf) private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] @@ -85,8 +95,37 @@ private[spark] class BlockManager( new TachyonStore(this, tachyonBlockManager) } - val blockManagerId = BlockManagerId( - executorId, blockTransferService.hostName, blockTransferService.port) + private[spark] + val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + + // Port used by the external shuffle service. In Yarn mode, this may be already be + // set through the Hadoop configuration as the server is launched in the Yarn NM. + private val externalShuffleServicePort = + Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + + // Check that we're not using external shuffle service with consolidated shuffle files. + if (externalShuffleServiceEnabled + && conf.getBoolean("spark.shuffle.consolidateFiles", false) + && shuffleManager.isInstanceOf[HashShuffleManager]) { + throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated" + + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or " + + " switch to sort-based shuffle.") + } + + var blockManagerId: BlockManagerId = _ + + // Address of the server that serves this executor's shuffle files. This is either an external + // service, or just our own Executor's BlockManager. + private[spark] var shuffleServerId: BlockManagerId = _ + + // Client to read other executors' shuffle files. This is either an external service, or just the + // standard BlockTranserService to directly connect to other Executors. + private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { + new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), securityManager, + securityManager.isAuthenticationEnabled()) + } else { + blockTransferService + } // Whether to compress broadcast variables that are stored private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) @@ -116,8 +155,6 @@ private[spark] class BlockManager( private val peerFetchLock = new Object private var lastPeerFetchTime = 0L - initialize() - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. The reason is that a Spark * program could be using a user-defined codec in a third party jar, which is loaded in @@ -136,17 +173,65 @@ private[spark] class BlockManager( conf: SparkConf, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService) = { + blockTransferService: BlockTransferService, + securityManager: SecurityManager) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, mapOutputTracker, shuffleManager, blockTransferService) + conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) } /** - * Initialize the BlockManager. Register to the BlockManagerMaster, and start the - * BlockManagerWorker actor. + * Initializes the BlockManager with the given appId. This is not performed in the constructor as + * the appId may not be known at BlockManager instantiation time (in particular for the driver, + * where it is only learned after registration with the TaskScheduler). + * + * This method initializes the BlockTransferService and ShuffleClient, registers with the + * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle + * service if configured. */ - private def initialize(): Unit = { + def initialize(appId: String): Unit = { + blockTransferService.init(this) + shuffleClient.init(appId) + + blockManagerId = BlockManagerId( + executorId, blockTransferService.hostName, blockTransferService.port) + + shuffleServerId = if (externalShuffleServiceEnabled) { + BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) + } else { + blockManagerId + } + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + + // Register Executors' configuration with the local shuffle service, if one should exist. + if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { + registerWithExternalShuffleServer() + } + } + + private def registerWithExternalShuffleServer() { + logInfo("Registering executor with local external shuffle service.") + val shuffleConfig = new ExecutorShuffleInfo( + diskBlockManager.localDirs.map(_.toString), + diskBlockManager.subDirsPerLocalDir, + shuffleManager.getClass.getName) + + val MAX_ATTEMPTS = 3 + val SLEEP_TIME_SECS = 5 + + for (i <- 1 to MAX_ATTEMPTS) { + try { + // Synchronous and will throw an exception if we cannot connect. + shuffleClient.asInstanceOf[ExternalShuffleClient].registerWithShuffleServer( + shuffleServerId.host, shuffleServerId.port, shuffleServerId.executorId, shuffleConfig) + return + } catch { + case e: Exception if i < MAX_ATTEMPTS => + logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}" + + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) + Thread.sleep(SLEEP_TIME_SECS * 1000) + } + } } /** @@ -506,7 +591,7 @@ private[spark] class BlockManager( for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") val data = blockTransferService.fetchBlockSync( - loc.host, loc.port, blockId.toString).nioByteBuffer() + loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() if (data != null) { if (asBlockResult) { @@ -855,7 +940,7 @@ private[spark] class BlockManager( data.rewind() logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer") blockTransferService.uploadBlockSync( - peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel) + peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel) logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms" .format(System.currentTimeMillis - onePeerStartTime)) peersReplicatedTo += peer @@ -1113,6 +1198,10 @@ private[spark] class BlockManager( def stop(): Unit = { blockTransferService.close() + if (shuffleClient ne blockTransferService) { + // Closing should be idempotent, but maybe not for the NioBlockTransferService. + shuffleClient.close() + } diskBlockManager.stop() actorSystem.stop(slaveActor) blockInfo.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 259f423c73e6b..b177a59c721df 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.ConcurrentHashMap +import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils @@ -59,7 +60,7 @@ class BlockManagerId private ( def port: Int = port_ - def isDriver: Boolean = (executorId == "") + def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeUTF(executorId_) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index d08e1419e3e41..b63c7f191155c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -88,6 +88,10 @@ class BlockManagerMaster( askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } + def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + } + /** * Remove a block from the slaves that have it. This can only be used to remove * blocks that the driver knows about. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 5e375a2553979..685b2e11440fb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -86,6 +86,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetPeers(blockManagerId) => sender ! getPeers(blockManagerId) + case GetActorSystemHostPortForExecutor(executorId) => + sender ! getActorSystemHostPortForExecutor(executorId) + case GetMemoryStatus => sender ! memoryStatus @@ -412,6 +415,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus Seq.empty } } + + /** + * Returns the hostname and port of an executor's actor system, based on the Akka address of its + * BlockManagerSlaveActor. + */ + private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + for ( + blockManagerId <- blockManagerIdByExecutor.get(executorId); + info <- blockManagerInfo.get(blockManagerId); + host <- info.slaveActor.path.address.host; + port <- info.slaveActor.path.address.port + ) yield { + (host, port) + } + } } @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 291ddfcc113ac..3f32099d08cc9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -92,6 +92,8 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class RemoveExecutor(execId: String) extends ToBlockManagerMaster case object StopBlockManagerMaster extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 99e925328a4b9..58fba54710510 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -38,12 +38,13 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon extends Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64) + private[spark] + val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64) /* Create one local directory for each path mentioned in spark.local.dir; then, inside this * directory, create multiple subdirectories that we will hash files into, in order to avoid * having really large inodes at the top level. */ - val localDirs: Array[File] = createLocalDirs(conf) + private[spark] val localDirs: Array[File] = createLocalDirs(conf) if (localDirs.isEmpty) { logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) @@ -52,6 +53,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon addShutdownHook() + /** Looks up a file by hashing it into one of our local subdirectories. */ + // This method should be kept in sync with + // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getFile(). def getFile(filename: String): File = { // Figure out which local directory it hashes to, and which subdirectory in that val hash = Utils.nonNegativeHash(filename) @@ -159,13 +163,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { - localDirs.foreach { localDir => - if (localDir.isDirectory() && localDir.exists()) { - try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) - } catch { - case e: Exception => - logError(s"Exception while deleting local spark dir: $localDir", e) + // Only perform cleanup if an external service is not serving our shuffle files. + if (!blockManager.externalShuffleServiceEnabled) { + localDirs.foreach { localDir => + if (localDir.isDirectory() && localDir.exists()) { + try { + if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + } catch { + case e: Exception => + logError(s"Exception while deleting local spark dir: $localDir", e) + } } } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 0d6f3bf003a9d..6b1f57a069431 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -20,9 +20,11 @@ package org.apache.spark.storage import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} +import scala.util.{Failure, Success, Try} import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.serializer.Serializer import org.apache.spark.util.{CompletionIterator, Utils} @@ -38,8 +40,8 @@ import org.apache.spark.util.{CompletionIterator, Utils} * using too much memory. * * @param context [[TaskContext]], used for metrics update - * @param blockTransferService [[BlockTransferService]] for fetching remote blocks - * @param blockManager [[BlockManager]] for reading local blocks + * @param shuffleClient [[ShuffleClient]] for fetching remote blocks + * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. @@ -49,12 +51,12 @@ import org.apache.spark.util.{CompletionIterator, Utils} private[spark] final class ShuffleBlockFetcherIterator( context: TaskContext, - blockTransferService: BlockTransferService, + shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { import ShuffleBlockFetcherIterator._ @@ -90,7 +92,7 @@ final class ShuffleBlockFetcherIterator( * Current [[FetchResult]] being processed. We track this so we can release the current buffer * in case of a runtime exception when processing the current buffer. */ - private[this] var currentResult: FetchResult = null + @volatile private[this] var currentResult: FetchResult = null /** * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that @@ -117,16 +119,18 @@ final class ShuffleBlockFetcherIterator( private[this] def cleanup() { isZombie = true // Release the current buffer if necessary - if (currentResult != null && !currentResult.failed) { - currentResult.buf.release() + currentResult match { + case SuccessFetchResult(_, _, buf) => buf.release() + case _ => } // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { val result = iter.next() - if (!result.failed) { - result.buf.release() + result match { + case SuccessFetchResult(_, _, buf) => buf.release() + case _ => } } } @@ -140,7 +144,8 @@ final class ShuffleBlockFetcherIterator( val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap val blockIds = req.blocks.map(_._1.toString) - blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds, + val address = req.address + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, new BlockFetchingListener { override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { // Only add the buffer to results queue if the iterator is not zombie, @@ -149,7 +154,7 @@ final class ShuffleBlockFetcherIterator( // Increment the ref count because we need to pass this to a different thread. // This needs to be released after use. buf.retain() - results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), buf)) + results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf)) shuffleMetrics.remoteBytesRead += buf.size shuffleMetrics.remoteBlocksFetched += 1 } @@ -158,7 +163,7 @@ final class ShuffleBlockFetcherIterator( override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FetchResult(BlockId(blockId), -1, null)) + results.put(new FailureFetchResult(BlockId(blockId), e)) } } ) @@ -179,7 +184,7 @@ final class ShuffleBlockFetcherIterator( var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { totalBlocks += blockInfos.size - if (address == blockManager.blockManagerId) { + if (address.executorId == blockManager.blockManagerId.executorId) { // Filter out zero-sized blocks localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) numBlocksToFetch += localBlocks.size @@ -229,12 +234,12 @@ final class ShuffleBlockFetcherIterator( val buf = blockManager.getBlockData(blockId) shuffleMetrics.localBlocksFetched += 1 buf.retain() - results.put(new FetchResult(blockId, 0, buf)) + results.put(new SuccessFetchResult(blockId, 0, buf)) } catch { case e: Exception => // If we see an exception, stop immediately. logError(s"Error occurred while fetching local blocks", e) - results.put(new FetchResult(blockId, -1, null)) + results.put(new FailureFetchResult(blockId, e)) return } } @@ -265,15 +270,17 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Option[Iterator[Any]]) = { + override def next(): (BlockId, Try[Iterator[Any]]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() val result = currentResult val stopFetchWait = System.currentTimeMillis() shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) - if (!result.failed) { - bytesInFlight -= result.size + + result match { + case SuccessFetchResult(_, size, _) => bytesInFlight -= size + case _ => } // Send fetch requests up to maxBytesInFlight while (fetchRequests.nonEmpty && @@ -281,20 +288,21 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorOpt: Option[Iterator[Any]] = if (result.failed) { - None - } else { - val is = blockManager.wrapForCompression(result.blockId, result.buf.createInputStream()) - val iter = serializer.newInstance().deserializeStream(is).asIterator - Some(CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. - currentResult = null - result.buf.release() - })) + val iteratorTry: Try[Iterator[Any]] = result match { + case FailureFetchResult(_, e) => Failure(e) + case SuccessFetchResult(blockId, _, buf) => { + val is = blockManager.wrapForCompression(blockId, buf.createInputStream()) + val iter = serializer.newInstance().deserializeStream(is).asIterator + Success(CompletionIterator[Any, Iterator[Any]](iter, { + // Once the iterator is exhausted, release the buffer and set currentResult to null + // so we don't release it again in cleanup. + currentResult = null + buf.release() + })) + } } - (result.blockId, iteratorOpt) + (result.blockId, iteratorTry) } } @@ -313,14 +321,30 @@ object ShuffleBlockFetcherIterator { } /** - * Result of a fetch from a remote block. A failure is represented as size == -1. + * Result of a fetch from a remote block. + */ + private[storage] sealed trait FetchResult { + val blockId: BlockId + } + + /** + * Result of a fetch from a remote block successfully. * @param blockId block id * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. -1 if failure is present. - * @param buf [[ManagedBuffer]] for the content. null is error. + * Note that this is NOT the exact bytes. + * @param buf [[ManagedBuffer]] for the content. */ - case class FetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) { - def failed: Boolean = size == -1 - if (failed) assert(buf == null) else assert(buf != null) + private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) + extends FetchResult { + require(buf != null) + require(size >= 0) } + + /** + * Result of a fetch from a remote block unsuccessfully. + * @param blockId block id + * @param e the failure exception + */ + private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable) + extends FetchResult } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index d9066f766476e..def49e80a3605 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import scala.collection.mutable +import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ @@ -59,10 +60,9 @@ class StorageStatusListener extends SparkListener { val info = taskEnd.taskInfo val metrics = taskEnd.taskMetrics if (info != null && metrics != null) { - val execId = formatExecutorId(info.executorId) val updatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) if (updatedBlocks.length > 0) { - updateStorageStatus(execId, updatedBlocks) + updateStorageStatus(info.executorId, updatedBlocks) } } } @@ -88,13 +88,4 @@ class StorageStatusListener extends SparkListener { } } - /** - * In the local mode, there is a discrepancy between the executor ID according to the - * task ("localhost") and that according to SparkEnv (""). In the UI, this - * results in duplicate rows for the same executor. Thus, in this mode, we aggregate - * these two rows and use the executor ID of "" to be consistent. - */ - def formatExecutorId(execId: String): String = { - if (execId == "localhost") "" else execId - } } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 9ced9b8107ebf..6f446c5a95a0a 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,11 +24,28 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" + val TASK_DESERIALIZATION_TIME = + """Time spent deserializating the task closure on the executor.""" + val INPUT = "Bytes read from Hadoop or from Spark storage." + val OUTPUT = "Bytes written to Hadoop." + val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage." val SHUFFLE_READ = """Bytes read from remote executors. Typically less than shuffle write bytes because this does not include shuffle data read locally.""" + + val GETTING_RESULT_TIME = + """Time that the driver spends fetching task results from workers. If this is large, consider + decreasing the amount of data returned from each task.""" + + val RESULT_SERIALIZATION_TIME = + """Time spent serializing the task result on the executor before sending it back to the + driver.""" + + val GC_TIME = + """Time that the executor spent paused for Java garbage collection while the task was + running.""" } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 76714b1e6964f..3312671b6f885 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -20,13 +20,13 @@ package org.apache.spark.ui import java.text.SimpleDateFormat import java.util.{Locale, Date} -import scala.xml.{Text, Node} +import scala.xml.{Node, Text} import org.apache.spark.Logging /** Utility functions for generating XML pages with spark content. */ private[spark] object UIUtils extends Logging { - val TABLE_CLASS = "table table-bordered table-striped table-condensed sortable" + val TABLE_CLASS = "table table-bordered table-striped-custom table-condensed sortable" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { @@ -160,6 +160,8 @@ private[spark] object UIUtils extends Logging { + + } /** Returns a spark page with correctly formatted headers */ @@ -240,7 +242,8 @@ private[spark] object UIUtils extends Logging { generateDataRow: T => Seq[Node], data: Iterable[T], fixedWidth: Boolean = false, - id: Option[String] = None): Seq[Node] = { + id: Option[String] = None, + headerClasses: Seq[String] = Seq.empty): Seq[Node] = { var listingTableClass = TABLE_CLASS if (fixedWidth) { @@ -248,20 +251,29 @@ private[spark] object UIUtils extends Logging { } val colWidth = 100.toDouble / headers.size val colWidthAttr = if (fixedWidth) colWidth + "%" else "" - val headerRow: Seq[Node] = { - // if none of the headers have "\n" in them - if (headers.forall(!_.contains("\n"))) { - // represent header as simple text - headers.map(h => {h}) + + def getClass(index: Int): String = { + if (index < headerClasses.size) { + headerClasses(index) } else { - // represent header text as list while respecting "\n" - headers.map { case h => - -
    - { h.split("\n").map { case t =>
  • {t}
  • } } -
- - } + "" + } + } + + val newlinesInHeader = headers.exists(_.contains("\n")) + def getHeaderContent(header: String): Seq[Node] = { + if (newlinesInHeader) { +
    + { header.split("\n").map { case t =>
  • {t}
  • } } +
+ } else { + Text(header) + } + } + + val headerRow: Seq[Node] = { + headers.view.zipWithIndex.map { x => + {getHeaderContent(x._1)} } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala new file mode 100644 index 0000000000000..e9c755e36f716 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.exec + +import javax.servlet.http.HttpServletRequest + +import scala.util.Try +import scala.xml.{Text, Node} + +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") { + + private val sc = parent.sc + + def render(request: HttpServletRequest): Seq[Node] = { + val executorId = Option(request.getParameter("executorId")).getOrElse { + return Text(s"Missing executorId parameter") + } + val time = System.currentTimeMillis() + val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) + + val content = maybeThreadDump.map { threadDump => + val dumpRows = threadDump.map { thread => + + } + +
+

Updated at {UIUtils.formatDate(time)}

+ { + // scalastyle:off +

+ Expand All +

+

+ // scalastyle:on + } +
{dumpRows}
+
+ }.getOrElse(Text("Error fetching thread dump")) + UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent) + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b0e3bb3b552fd..048fee3ce1ff4 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -41,7 +41,10 @@ private case class ExecutorSummaryInfo( totalShuffleWrite: Long, maxMemory: Long) -private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { +private[ui] class ExecutorsPage( + parent: ExecutorsTab, + threadDumpEnabled: Boolean) + extends WebUIPage("") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -75,6 +78,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { Shuffle Write + {if (threadDumpEnabled) else Seq.empty} {execInfoSorted.map(execRow)} @@ -133,6 +137,15 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { + { + if (threadDumpEnabled) { + + } else { + Seq.empty + } + } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 689cf02b25b70..dd1c2b78c4094 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -27,8 +27,14 @@ import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { val listener = parent.executorsListener + val sc = parent.sc + val threadDumpEnabled = + sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true) - attachPage(new ExecutorsPage(this)) + attachPage(new ExecutorsPage(this, threadDumpEnabled)) + if (threadDumpEnabled) { + attachPage(new ExecutorThreadDumpPage(this)) + } } /** @@ -42,20 +48,21 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp val executorToTasksFailed = HashMap[String, Int]() val executorToDuration = HashMap[String, Long]() val executorToInputBytes = HashMap[String, Long]() + val executorToOutputBytes = HashMap[String, Long]() val executorToShuffleRead = HashMap[String, Long]() val executorToShuffleWrite = HashMap[String, Long]() def storageStatusList = storageStatusListener.storageStatusList override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { - val eid = formatExecutorId(taskStart.taskInfo.executorId) + val eid = taskStart.taskInfo.executorId executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1 } override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val info = taskEnd.taskInfo if (info != null) { - val eid = formatExecutorId(info.executorId) + val eid = info.executorId executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1 executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration taskEnd.reason match { @@ -72,6 +79,10 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp executorToInputBytes(eid) = executorToInputBytes.getOrElse(eid, 0L) + inputMetrics.bytesRead } + metrics.outputMetrics.foreach { outputMetrics => + executorToOutputBytes(eid) = + executorToOutputBytes.getOrElse(eid, 0L) + outputMetrics.bytesWritten + } metrics.shuffleReadMetrics.foreach { shuffleRead => executorToShuffleRead(eid) = executorToShuffleRead.getOrElse(eid, 0L) + shuffleRead.remoteBytesRead @@ -84,6 +95,4 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp } } - // This addresses executor ID inconsistencies in the local mode - private def formatExecutorId(execId: String) = storageStatusListener.formatExecutorId(execId) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index f0e43fbf70976..fa0f96bff34ff 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -45,6 +45,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr + @@ -77,6 +78,8 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr + val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow, accumulables.values.toSeq) - val taskHeaders: Seq[String] = + val taskHeadersAndCssClasses: Seq[(String, String)] = Seq( - "Index", "ID", "Attempt", "Status", "Locality Level", "Executor ID / Host", - "Launch Time", "Duration", "GC Time", "Accumulators") ++ - {if (hasInput) Seq("Input") else Nil} ++ - {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ - {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++ - {if (hasBytesSpilled) Seq("Shuffle Spill (Memory)", "Shuffle Spill (Disk)") else Nil} ++ - Seq("Errors") + ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), + ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), + ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), + ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), + ("GC Time", TaskDetailsClassNames.GC_TIME), + ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), + ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ + {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ + {if (hasInput) Seq(("Input", "")) else Nil} ++ + {if (hasOutput) Seq(("Output", "")) else Nil} ++ + {if (hasShuffleRead) Seq(("Shuffle Read", "")) else Nil} ++ + {if (hasShuffleWrite) Seq(("Write Time", ""), ("Shuffle Write", "")) else Nil} ++ + {if (hasBytesSpilled) Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + else Nil} ++ + Seq(("Errors", "")) + + val unzipped = taskHeadersAndCssClasses.unzip val taskTable = UIUtils.listingTable( - taskHeaders, taskRow(hasInput, hasShuffleRead, hasShuffleWrite, hasBytesSpilled), tasks) - + unzipped._1, + taskRow(hasAccumulators, hasInput, hasOutput, hasShuffleRead, hasShuffleWrite, + hasBytesSpilled), + tasks, + headerClasses = unzipped._2) // Excludes tasks which failed and have incomplete metrics val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined) @@ -122,18 +192,48 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { None } else { - val serializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.resultSerializationTime.toDouble + def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { + Distribution(times).get.getQuantiles().map { millis => + + } } - val serializationQuantiles = - +: Distribution(serializationTimes). - get.getQuantiles().map(ms => ) + + val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.executorDeserializeTime.toDouble + } + val deserializationQuantiles = + +: getFormattedTimeQuantiles(deserializationTimes) val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorRunTime.toDouble } - val serviceQuantiles = +: Distribution(serviceTimes).get.getQuantiles() - .map(ms => ) + val serviceQuantiles = +: getFormattedTimeQuantiles(serviceTimes) + + val gcTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.jvmGCTime.toDouble + } + val gcQuantiles = + +: getFormattedTimeQuantiles(gcTimes) + + val serializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.resultSerializationTime.toDouble + } + val serializationQuantiles = + +: getFormattedTimeQuantiles(serializationTimes) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => if (info.gettingResultTime > 0) { @@ -142,76 +242,84 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { 0.0 } } - val gettingResultQuantiles = +: - Distribution(gettingResultTimes).get.getQuantiles().map { millis => - - } + val gettingResultQuantiles = + +: + getFormattedTimeQuantiles(gettingResultTimes) // The scheduler delay includes the network delay to send the task to the worker // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) => - val totalExecutionTime = { - if (info.gettingResultTime > 0) { - (info.gettingResultTime - info.launchTime).toDouble - } else { - (info.finishTime - info.launchTime).toDouble - } - } - totalExecutionTime - metrics.get.executorRunTime + getSchedulerDelay(info, metrics.get).toDouble } val schedulerDelayTitle = + title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay val schedulerDelayQuantiles = schedulerDelayTitle +: - Distribution(schedulerDelays).get.getQuantiles().map { millis => - - } + getFormattedTimeQuantiles(schedulerDelays) - def getQuantileCols(data: Seq[Double]) = + def getFormattedSizeQuantiles(data: Seq[Double]) = Distribution(data).get.getQuantiles().map(d => ) val inputSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble } - val inputQuantiles = +: getQuantileCols(inputSizes) + val inputQuantiles = +: getFormattedSizeQuantiles(inputSizes) + + val outputSizes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble + } + val outputQuantiles = +: getFormattedSizeQuantiles(outputSizes) val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } val shuffleReadQuantiles = +: - getQuantileCols(shuffleReadSizes) + getFormattedSizeQuantiles(shuffleReadSizes) val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble } - val shuffleWriteQuantiles = +: getQuantileCols(shuffleWriteSizes) + val shuffleWriteQuantiles = +: + getFormattedSizeQuantiles(shuffleWriteSizes) val memoryBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.memoryBytesSpilled.toDouble } val memoryBytesSpilledQuantiles = +: - getQuantileCols(memoryBytesSpilledSizes) + getFormattedSizeQuantiles(memoryBytesSpilledSizes) val diskBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.diskBytesSpilled.toDouble } val diskBytesSpilledQuantiles = +: - getQuantileCols(diskBytesSpilledSizes) + getFormattedSizeQuantiles(diskBytesSpilledSizes) val listings: Seq[Seq[Node]] = Seq( - serializationQuantiles, - serviceQuantiles, - gettingResultQuantiles, - schedulerDelayQuantiles, - if (hasInput) inputQuantiles else Nil, - if (hasShuffleRead) shuffleReadQuantiles else Nil, - if (hasShuffleWrite) shuffleWriteQuantiles else Nil, - if (hasBytesSpilled) memoryBytesSpilledQuantiles else Nil, - if (hasBytesSpilled) diskBytesSpilledQuantiles else Nil) + {serviceQuantiles}, + {schedulerDelayQuantiles}, + + {deserializationQuantiles} + + {gcQuantiles}, + + {serializationQuantiles} + , + {gettingResultQuantiles}, + if (hasInput) {inputQuantiles} else Nil, + if (hasOutput) {outputQuantiles} else Nil, + if (hasShuffleRead) {shuffleReadQuantiles} else Nil, + if (hasShuffleWrite) {shuffleWriteQuantiles} else Nil, + if (hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, + if (hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", "Max") - def quantileRow(data: Seq[Node]): Seq[Node] = {data} - Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) + Some(UIUtils.listingTable( + quantileHeaders, identity[Seq[Node]], listings, fixedWidth = true)) } val executorTable = new ExecutorTable(stageId, stageAttemptId, parent) @@ -221,6 +329,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val content = summary ++ + showAdditionalMetrics ++

Summary Metrics for {numCompleted} Completed Tasks

++
{summaryTable.getOrElse("No tasks have reported metrics yet.")}
++

Aggregated Metrics by Executor

++ executorTable.toNodeSeq ++ @@ -232,7 +341,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { } def taskRow( + hasAccumulators: Boolean, hasInput: Boolean, + hasOutput: Boolean, hasShuffleRead: Boolean, hasShuffleWrite: Boolean, hasBytesSpilled: Boolean)(taskData: TaskUIData): Seq[Node] = { @@ -241,8 +352,14 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { else metrics.map(_.executorRunTime).getOrElse(1L) val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") + val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L) val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) + val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) + val gettingResultTime = info.gettingResultTime + + val maybeAccumulators = info.accumulables + val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} val maybeInput = metrics.flatMap(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead.toString).getOrElse("") @@ -250,6 +367,12 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") .getOrElse("") + val maybeOutput = metrics.flatMap(_.outputMetrics) + val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("") + val outputReadable = maybeOutput + .map(m => s"${Utils.bytesToString(m.bytesWritten)}") + .getOrElse("") + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead) val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("") @@ -287,25 +410,40 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { - + + - - + {if (hasAccumulators) { + + }} {if (hasInput) { }} + {if (hasOutput) { + + }} {if (hasShuffleRead) { }} - + {errorMessageCell(errorMessage)} } } + + private def errorMessageCell(errorMessage: Option[String]): Seq[Node] = { + val error = errorMessage.getOrElse("") + val isMultiline = error.indexOf('\n') >= 0 + // Display the first line by default + val errorSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + error.substring(0, error.indexOf('\n')) + } else { + error + }) + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + + } + + private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { + val totalExecutionTime = { + if (info.gettingResultTime > 0) { + (info.gettingResultTime - info.launchTime) + } else { + (info.finishTime - info.launchTime) + } + } + val executorOverhead = (metrics.executorDeserializeTime + + metrics.resultSerializationTime) + totalExecutionTime - metrics.executorRunTime - executorOverhead + } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 4ee7f08ab47a2..eae542df85d08 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -22,6 +22,8 @@ import scala.xml.Text import java.util.Date +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.scheduler.StageInfo import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.util.Utils @@ -43,6 +45,7 @@ private[ui] class StageTableBase( + + + + + + @@ -359,6 +377,16 @@ Apart from these, the following properties are also available, and may be useful map-side aggregation and there are at most this many reduce partitions. + + + + +
Thread Dump
{Utils.bytesToString(info.totalShuffleWrite)} + Thread Dump +
Failed Tasks Succeeded Tasks InputOutput Shuffle Read Shuffle Write Shuffle Spill (Memory){v.succeededTasks} {Utils.bytesToString(v.inputBytes)} + {Utils.bytesToString(v.outputBytes)} {Utils.bytesToString(v.shuffleRead)} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index b5207360510dd..8bbde51e1801c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -59,6 +59,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val failedStages = ListBuffer[StageInfo]() val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData] val stageIdToInfo = new HashMap[StageId, StageInfo] + + // Number of completed and failed stages, may not actually equal to completedStages.size and + // failedStages.size respectively due to completedStage and failedStages only maintain the latest + // part of the stages, the earlier ones will be removed when there are too many stages for + // memory sake. + var numCompletedStages = 0 + var numFailedStages = 0 // Map from pool name to a hash map (map from stage id to StageInfo). val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]() @@ -110,9 +117,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { activeStages.remove(stage.stageId) if (stage.failureReason.isEmpty) { completedStages += stage + numCompletedStages += 1 trimIfNecessary(completedStages) } else { failedStages += stage + numFailedStages += 1 trimIfNecessary(failedStages) } } @@ -250,6 +259,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.inputBytes += inputBytesDelta execSummary.inputBytes += inputBytesDelta + val outputBytesDelta = + (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L) + - oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L)) + stageData.outputBytes += outputBytesDelta + execSummary.outputBytes += outputBytesDelta + val diskSpillDelta = taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) stageData.diskBytesSpilled += diskSpillDelta diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala index 6e718eecdd52a..83a7898071c9b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala @@ -34,7 +34,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") listener.synchronized { val activeStages = listener.activeStages.values.toSeq val completedStages = listener.completedStages.reverse.toSeq + val numCompletedStages = listener.numCompletedStages val failedStages = listener.failedStages.reverse.toSeq + val numFailedStages = listener.numFailedStages val now = System.currentTimeMillis val activeStagesTable = @@ -69,11 +71,11 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
  • Completed Stages: - {completedStages.size} + {numCompletedStages}
  • Failed Stages: - {failedStages.size} + {numFailedStages}
  • @@ -86,9 +88,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") }} ++

    Active Stages ({activeStages.size})

    ++ activeStagesTable.toNodeSeq ++ -

    Completed Stages ({completedStages.size})

    ++ +

    Completed Stages ({numCompletedStages})

    ++ completedStagesTable.toNodeSeq ++ -

    Failed Stages ({failedStages.size})

    ++ +

    Failed Stages ({numFailedStages})

    ++ failedStagesTable.toNodeSeq UIUtils.headerSparkPage("Spark Stages", content, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2414e4c65237e..16bc3f6c18d09 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -22,10 +22,13 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} +import org.apache.commons.lang3.StringEscapeUtils + +import org.apache.spark.executor.TaskMetrics import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Utils, Distribution} -import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { @@ -52,12 +55,13 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val numCompleted = tasks.count(_.taskInfo.finished) val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables + val hasAccumulators = accumulables.size > 0 val hasInput = stageData.inputBytes > 0 + val hasOutput = stageData.outputBytes > 0 val hasShuffleRead = stageData.shuffleReadBytes > 0 val hasShuffleWrite = stageData.shuffleWriteBytes > 0 val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0 - // scalastyle:off val summary =
      @@ -65,55 +69,121 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { Total task time across all tasks: {UIUtils.formatDuration(stageData.executorRunTime)} - {if (hasInput) + {if (hasInput) {
    • Input: {Utils.bytesToString(stageData.inputBytes)}
    • - } - {if (hasShuffleRead) + }} + {if (hasOutput) { +
    • + Output: + {Utils.bytesToString(stageData.outputBytes)} +
    • + }} + {if (hasShuffleRead) {
    • Shuffle read: {Utils.bytesToString(stageData.shuffleReadBytes)}
    • - } - {if (hasShuffleWrite) + }} + {if (hasShuffleWrite) {
    • Shuffle write: {Utils.bytesToString(stageData.shuffleWriteBytes)}
    • - } - {if (hasBytesSpilled) -
    • - Shuffle spill (memory): - {Utils.bytesToString(stageData.memoryBytesSpilled)} -
    • -
    • - Shuffle spill (disk): - {Utils.bytesToString(stageData.diskBytesSpilled)} -
    • - } + }} + {if (hasBytesSpilled) { +
    • + Shuffle spill (memory): + {Utils.bytesToString(stageData.memoryBytesSpilled)} +
    • +
    • + Shuffle spill (disk): + {Utils.bytesToString(stageData.diskBytesSpilled)} +
    • + }}
    - // scalastyle:on + + val showAdditionalMetrics = +
    + + + Show additional metrics + + +
    + val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") def accumulableRow(acc: AccumulableInfo) =
    {acc.name}{acc.value}
    {UIUtils.formatDuration(millis.toLong)}Result serialization time{UIUtils.formatDuration(ms.toLong)} + + Task Deserialization Time + + Duration{UIUtils.formatDuration(ms.toLong)}Duration + GC Time + + + + Result Serialization Time + + Time spent fetching task results{UIUtils.formatDuration(millis.toLong)} + + Getting Result Time + + Scheduler delay{UIUtils.formatDuration(millis.toLong)}{Utils.bytesToString(d.toLong)}InputInputOutputShuffle Read (Remote)Shuffle WriteShuffle WriteShuffle spill (memory)Shuffle spill (disk)
    {formatDuration} + + {UIUtils.formatDuration(schedulerDelay.toLong)} + + {UIUtils.formatDuration(taskDeserializationTime.toLong)} + {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} - {Unparsed( - info.accumulables.map{acc => s"${acc.name}: ${acc.update.get}"}.mkString("
    ") - )} +
    + {UIUtils.formatDuration(serializationTime)} + {Unparsed(accumulatorsReadable.mkString("
    "))} +
    {inputReadable} + {outputReadable} + {shuffleReadReadable} @@ -327,10 +465,47 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {diskBytesSpilledReadable} - {errorMessage.map { e =>
    {e}
    }.getOrElse("")} -
    {errorSummary}{details}Duration Tasks: Succeeded/Total InputOutput Shuffle Read - # Building With Hive and JDBC Support To enable Hive integration for Spark SQL along with its JDBC server and CLI, add the `-Phive` profile to your existing build options. By default Spark will build with Hive 0.13.1 bindings. You can also build for Hive 0.12.0 using -the `-Phive-0.12.0` profile. NOTE: currently the JDBC server is only -supported for Hive 0.12.0. +the `-Phive-0.12.0` profile. {% highlight bash %} # Apache Hadoop 2.4.X with Hive 13 support mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package @@ -121,8 +118,8 @@ Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.o Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence: - mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-0.12.0 clean package - mvn -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 test + mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive clean package + mvn -Pyarn -Phadoop-2.3 -Phive test The ScalaTest plugin also supports running only a specific test suite as follows: @@ -185,16 +182,16 @@ can be set to control the SBT build. For example: Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. The following is an example of a correct (build, test) sequence: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 assembly - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive assembly + sbt/sbt -Pyarn -Phadoop-2.3 -Phive test To run only a specific test suite as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 "test-only org.apache.spark.repl.ReplSuite" + sbt/sbt -Pyarn -Phadoop-2.3 -Phive "test-only org.apache.spark.repl.ReplSuite" To run test suites of a specific sub project as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 core/test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive core/test # Speeding up Compilation with Zinc diff --git a/docs/configuration.md b/docs/configuration.md index 3007706a2586e..f0b396e21f198 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -21,16 +21,22 @@ application. These properties can be set directly on a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) passed to your `SparkContext`. `SparkConf` allows you to configure some of the common properties (e.g. master URL and application name), as well as arbitrary key-value pairs through the -`set()` method. For example, we could initialize an application as follows: +`set()` method. For example, we could initialize an application with two threads as follows: + +Note that we run with local[2], meaning two threads - which represents "minimal" parallelism, +which can help detect bugs that only exist when we run in a distributed context. {% highlight scala %} val conf = new SparkConf() - .setMaster("local") + .setMaster("local[2]") .setAppName("CountingSheep") .set("spark.executor.memory", "1g") val sc = new SparkContext(conf) {% endhighlight %} +Note that we can have more than 1 thread in local mode, and in cases like spark streaming, we may actually +require one to prevent any sort of starvation issues. + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different @@ -111,6 +117,18 @@ of the most common options to set are: (e.g. 512m, 2g).
    spark.driver.maxResultSize1g + Limit of total size of serialized results of all partitions for each Spark action (e.g. collect). + Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size + is above this limit. + Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory + and memory overhead of objects in JVM). Setting a proper limit can protect the driver from + out-of-memory errors. +
    spark.serializer org.apache.spark.serializer.
    JavaSerializer
    spark.shuffle.blockTransferServicenetty + Implementation to use for transferring shuffle and cached blocks between executors. There + are two implementations available: netty and nio. Netty-based + block transfer is intended to be simpler but equally efficient and is the default option + starting in 1.2. +
    #### Spark UI @@ -534,6 +562,9 @@ Apart from these, the following properties are also available, and may be useful spark.default.parallelism + For distributed shuffle operations like reduceByKey and join, the + largest number of partitions in a parent RDD. For operations like parallelize + with no parent RDDs, it depends on the cluster manager:
    • Local mode: number of cores on the local machine
    • Mesos fine grained mode: 8
    • @@ -541,8 +572,8 @@ Apart from these, the following properties are also available, and may be useful
    - Default number of tasks to use across the cluster for distributed shuffle operations - (groupByKey, reduceByKey, etc) when not set by user. + Default number of partitions in RDDs returned by transformations like join, + reduceByKey, and parallelize when not set by user. diff --git a/docs/index.md b/docs/index.md index edd622ec90f64..171d6ddad62f3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -112,6 +112,7 @@ options for deployment: **External Resources:** * [Spark Homepage](http://spark.apache.org) +* [Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK) * [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/), diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 7978e934fb36b..c696ae9c8e8c8 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -34,7 +34,7 @@ a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. * *epsilon* determines the distance threshold within which we consider k-means to have converged. -## Examples +### Examples
    @@ -153,3 +153,97 @@ provided in the [Self-Contained Applications](quick-start.html#self-contained-ap section of the Spark Quick Start guide. Be sure to also include *spark-mllib* to your build file as a dependency. + +## Streaming clustering + +When data arrive in a stream, we may want to estimate clusters dynamically, +updating them as new data arrive. MLlib provides support for streaming k-means clustering, +with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm +uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign +all points to their nearest cluster, compute new cluster centers, then update each cluster using: + +`\begin{equation} + c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t} +\end{equation}` +`\begin{equation} + n_{t+1} = n_t + m_t +\end{equation}` + +Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned +to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$` +is the number of points added to the cluster in the current batch. The decay factor `$\alpha$` +can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning; +with `$\alpha$=0` only the most recent data will be used. This is analogous to an +exponentially-weighted moving average. + +The decay can be specified using a `halfLife` parameter, which determines the +correct decay factor `a` such that, for data acquired +at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5. +The unit of time can be specified either as `batches` or `points` and the update rule +will be adjusted accordingly. + +### Examples + +This example shows how to estimate clusters on streaming data. + +
    + +
    + +First we import the neccessary classes. + +{% highlight scala %} + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.clustering.StreamingKMeans + +{% endhighlight %} + +Then we make an input stream of vectors for training, as well as a stream of labeled data +points for testing. We assume a StreamingContext `ssc` has been created, see +[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. + +{% highlight scala %} + +val trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse) +val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse) + +{% endhighlight %} + +We create a model with random clusters and specify the number of clusters to find + +{% highlight scala %} + +val numDimensions = 3 +val numClusters = 2 +val model = new StreamingKMeans() + .setK(numClusters) + .setDecayFactor(1.0) + .setRandomCenters(numDimensions, 0.0) + +{% endhighlight %} + +Now register the streams for training and testing and start the job, printing +the predicted cluster assignments on new data points as they arrive. + +{% highlight scala %} + +model.trainOn(trainingData) +model.predictOnValues(testData).print() + +ssc.start() +ssc.awaitTermination() + +{% endhighlight %} + +As you add new text files with data the cluster centers will update. Each training +point should be formatted as `[x1, x2, x3]`, and each test data point +should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier +(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` +you will see predictions. With new data, the cluster centers will change! + +
    + +
    diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 886d71df474bc..197bc77d506c6 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -203,6 +203,23 @@ for((synonym, cosineSimilarity) <- synonyms) { } {% endhighlight %}
    +
    +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.feature import Word2Vec + +sc = SparkContext(appName='Word2Vec') +inp = sc.textFile("text8_lines").map(lambda row: row.split(" ")) + +word2vec = Word2Vec() +model = word2vec.fit(inp) + +synonyms = model.findSynonyms('china', 40) + +for word, cosine_distance in synonyms: + print "{}: {}".format(word, cosine_distance) +{% endhighlight %} +
    ## StandardScaler diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 7f9d4c6563944..d5b044d94fdd7 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -88,11 +88,11 @@ JavaPairRDD predictionAndLabel = return new Tuple2(model.predict(p.features()), p.label()); } }); -double accuracy = 1.0 * predictionAndLabel.filter(new Function, Boolean>() { +double accuracy = predictionAndLabel.filter(new Function, Boolean>() { @Override public Boolean call(Tuple2 pl) { - return pl._1() == pl._2(); + return pl._1().equals(pl._2()); } - }).count() / test.count(); + }).count() / (double) test.count(); {% endhighlight %} diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 10a5131c07414..ca8c29218f52d 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -380,6 +380,46 @@ for (ChiSqTestResult result : featureTestResults) { {% endhighlight %} +
    +[`Statistics`](api/python/index.html#pyspark.mllib.stat.Statistics$) provides methods to +run Pearson's chi-squared tests. The following example demonstrates how to run and interpret +hypothesis tests. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.linalg import Vectors, Matrices +from pyspark.mllib.regresssion import LabeledPoint +from pyspark.mllib.stat import Statistics + +sc = SparkContext() + +vec = Vectors.dense(...) # a vector composed of the frequencies of events + +# compute the goodness of fit. If a second vector to test against is not supplied as a parameter, +# the test runs against a uniform distribution. +goodnessOfFitTestResult = Statistics.chiSqTest(vec) +print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom, + # test statistic, the method used, and the null hypothesis. + +mat = Matrices.dense(...) # a contingency matrix + +# conduct Pearson's independence test on the input contingency matrix +independenceTestResult = Statistics.chiSqTest(mat) +print independenceTestResult # summary of the test including the p-value, degrees of freedom... + +obs = sc.parallelize(...) # LabeledPoint(feature, label) . + +# The contingency table is constructed from an RDD of LabeledPoint and used to conduct +# the independence test. Returns an array containing the ChiSquaredTestResult for every feature +# against the label. +featureTestResults = Statistics.chiSqTest(obs) + +for i, result in enumerate(featureTestResults): + print "Column $d:" % (i + 1) + print result +{% endhighlight %} +
    + ## Random data generation diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 695813a2ba881..2f7e4981e5bb9 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -4,7 +4,7 @@ title: Running Spark on YARN --- Support for running on [YARN (Hadoop -NextGen)](http://hadoop.apache.org/docs/r2.0.2-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html) +NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) was added to Spark in version 0.6.0, and improved in subsequent releases. # Preparations diff --git a/docs/security.md b/docs/security.md index ec0523184d665..1e206a139fb72 100644 --- a/docs/security.md +++ b/docs/security.md @@ -7,7 +7,6 @@ Spark currently supports authentication via a shared secret. Authentication can * For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. -* **IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.* ## Web UI diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d4ade939c3a6e..ffcce2c588879 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -582,19 +582,27 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or spark.sql.parquet.cacheMetadata - false + true Turns on caching of Parquet schema metadata. Can speed up querying of static data. spark.sql.parquet.compression.codec - snappy + gzip Sets the compression codec use when writing Parquet files. Acceptable values include: uncompressed, snappy, gzip, lzo. + + spark.sql.hive.convertMetastoreParquet + true + + When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of the built in + support. + + ## JSON Datasets @@ -815,7 +823,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL Property NameDefaultMeaning spark.sql.inMemoryColumnarStorage.compressed - false + true When set to true Spark SQL will automatically select a compression codec for each column based on statistics of the data. @@ -823,7 +831,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL spark.sql.inMemoryColumnarStorage.batchSize - 1000 + 10000 Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization and compression, but risk OOMs when caching data. @@ -841,7 +849,7 @@ that these options will be deprecated in future release as more optimizations ar Property NameDefaultMeaning spark.sql.autoBroadcastJoinThreshold - 10000 + 10485760 (10 MB) Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently @@ -1051,7 +1059,6 @@ in Hive deployments. **Major Hive Features** -* Spark SQL does not currently support inserting to tables using dynamic partitioning. * Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL doesn't support buckets yet. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 8bbba88b31978..44a1f3ad7560b 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -68,7 +68,9 @@ import org.apache.spark._ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ -// Create a local StreamingContext with two working thread and batch interval of 1 second +// Create a local StreamingContext with two working thread and batch interval of 1 second. +// The master requires 2 cores to prevent from a starvation scenario. + val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") val ssc = new StreamingContext(conf, Seconds(1)) {% endhighlight %} @@ -586,11 +588,13 @@ Every input DStream (except file stream) is associated with a single [Receiver]( A receiver is run within a Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the Spark Streaming application. Hence, it is important to remember that Spark Streaming application needs to be allocated enough cores to process the received data, as well as, to run the receiver(s). Therefore, few important points to remember are: -##### Points to remember: +##### Points to remember {:.no_toc} -- If the number of cores allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them. -- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs with even one input DStream (file streams are okay) as the receiver will occupy that core and there will be no core left to process the data. - +- If the number of threads allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them. +- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs using a DStream as the receiver (file streams are okay). So, a "local" master URL in a streaming app is generally going to cause starvation for the processor. +Thus in any streaming app, you generally will want to allocate more than one thread (i.e. set your master to "local[2]") when testing locally. +See [Spark Properties] (configuration.html#spark-properties.html). + ### Basic Sources {:.no_toc} diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 index 31f9771223e51..4aa908242eeaa 100755 --- a/ec2/spark-ec2 +++ b/ec2/spark-ec2 @@ -18,5 +18,9 @@ # limitations under the License. # -cd "`dirname $0`" -PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py "$@" +# Preserve the user's CWD so that relative paths are passed correctly to +#+ the underlying Python script. +SPARK_EC2_DIR="$(dirname $0)" + +PYTHONPATH="${SPARK_EC2_DIR}/third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" \ + python "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 0d6b82b4944f3..a5396c2375915 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -40,9 +40,11 @@ from boto import ec2 DEFAULT_SPARK_VERSION = "1.1.0" +SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) +MESOS_SPARK_EC2_BRANCH = "v4" # A URL prefix from which to fetch AMI information -AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list" +AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH) class UsageError(Exception): @@ -583,10 +585,23 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten - ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v4") + ssh( + host=master, + opts=opts, + command="rm -rf spark-ec2" + + " && " + + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH) + ) print "Deploying files to master..." - deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules) + deploy_files( + conn=conn, + root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", + opts=opts, + master_nodes=master_nodes, + slave_nodes=slave_nodes, + modules=modules + ) print "Running setup on master..." setup_spark_cluster(master, opts) @@ -723,6 +738,8 @@ def get_num_disks(instance_type): # cluster (e.g. lists of masters and slaves). Files are only deployed to # the first master instance in the cluster, and we expect the setup # script to be run on that instance to copy them to other nodes. +# +# root_dir should be an absolute path to the directory with the files we want to deploy. def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): active_master = master_nodes[0].public_dns_name diff --git a/examples/pom.xml b/examples/pom.xml index bc3291803c324..910eb55308b9d 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -50,6 +50,30 @@
    + + hbase-hadoop2 + + + hbase.profile + hadoop2 + + + + 0.98.7-hadoop2 + + + + hbase-hadoop1 + + + !hbase.profile + + + + 0.98.7-hadoop1 + + + @@ -120,37 +144,122 @@ spark-streaming-mqtt_${scala.binary.version} ${project.version} - - org.apache.hbase - hbase - ${hbase.version} - - - asm - asm - - - org.jboss.netty - netty - - - io.netty - netty - - - commons-logging - commons-logging - - - org.jruby - jruby-complete - - - org.eclipse.jetty jetty-server + + org.apache.hbase + hbase-testing-util + ${hbase.version} + + + org.jruby + jruby-complete + + + + + org.apache.hbase + hbase-protocol + ${hbase.version} + + + org.apache.hbase + hbase-common + ${hbase.version} + + + org.apache.hbase + hbase-client + ${hbase.version} + + + io.netty + netty + + + + + org.apache.hbase + hbase-server + ${hbase.version} + + + org.apache.hadoop + hadoop-core + + + org.apache.hadoop + hadoop-client + + + org.apache.hadoop + hadoop-mapreduce-client-jobclient + + + org.apache.hadoop + hadoop-mapreduce-client-core + + + org.apache.hadoop + hadoop-auth + + + org.apache.hadoop + hadoop-annotations + + + org.apache.hadoop + hadoop-hdfs + + + org.apache.hbase + hbase-hadoop1-compat + + + org.apache.commons + commons-math + + + com.sun.jersey + jersey-core + + + org.slf4j + slf4j-api + + + com.sun.jersey + jersey-server + + + com.sun.jersey + jersey-core + + + com.sun.jersey + jersey-json + + + + commons-io + commons-io + + + + + org.apache.hbase + hbase-hadoop-compat + ${hbase.version} + + + org.apache.hbase + hbase-hadoop-compat + ${hbase.version} + test-jar + test + com.twitter algebird-core_${scala.binary.version} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java index 6c177de359b60..31a79ddd3fff1 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java @@ -30,12 +30,25 @@ /** * Logistic regression based classification. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ public final class JavaHdfsLR { private static final int D = 10; // Number of dimensions private static final Random rand = new Random(42); + static void showWarning() { + String warning = "WARN: This is a naive implementation of Logistic Regression " + + "and is given as an example!\n" + + "Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD " + + "or org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS " + + "for more conventional use."; + System.err.println(warning); + } + static class DataPoint implements Serializable { DataPoint(double[] x, double y) { this.x = x; @@ -109,6 +122,8 @@ public static void main(String[] args) { System.exit(1); } + showWarning(); + SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR"); JavaSparkContext sc = new JavaSparkContext(sparkConf); JavaRDD lines = sc.textFile(args[0]); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index c22506491fbff..a5db8accdf138 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -45,10 +45,21 @@ * URL neighbor URL * ... * where URL and their neighbors are separated by space(s). + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.graphx.lib.PageRank */ public final class JavaPageRank { private static final Pattern SPACES = Pattern.compile("\\s+"); + static void showWarning() { + String warning = "WARN: This is a naive implementation of PageRank " + + "and is given as an example! \n" + + "Please use the PageRank implementation found in " + + "org.apache.spark.graphx.lib.PageRank for more conventional use."; + System.err.println(warning); + } + private static class Sum implements Function2 { @Override public Double call(Double a, Double b) { @@ -62,6 +73,8 @@ public static void main(String[] args) throws Exception { System.exit(1); } + showWarning(); + SparkConf sparkConf = new SparkConf().setAppName("JavaPageRank"); JavaSparkContext ctx = new JavaSparkContext(sparkConf); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java new file mode 100644 index 0000000000000..1af2067b2b929 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.GradientBoosting; +import org.apache.spark.mllib.tree.configuration.BoostingStrategy; +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel; +import org.apache.spark.mllib.util.MLUtils; + +/** + * Classification and regression using gradient-boosted decision trees. + */ +public final class JavaGradientBoostedTrees { + + private static void usage() { + System.err.println("Usage: JavaGradientBoostedTrees " + + " "); + System.exit(-1); + } + + public static void main(String[] args) { + String datapath = "data/mllib/sample_libsvm_data.txt"; + String algo = "Classification"; + if (args.length >= 1) { + datapath = args[0]; + } + if (args.length >= 2) { + algo = args[1]; + } + if (args.length > 2) { + usage(); + } + SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); + + // Set parameters. + // Note: All features are treated as continuous. + BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); + boostingStrategy.setNumIterations(10); + boostingStrategy.weakLearnerParams().setMaxDepth(5); + + if (algo.equals("Classification")) { + // Compute the number of classes from the data. + Integer numClasses = data.map(new Function() { + @Override public Double call(LabeledPoint p) { + return p.label(); + } + }).countByValue().size(); + boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression + + // Train a GradientBoosting model for classification. + final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double trainErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / data.count(); + System.out.println("Training error: " + trainErr); + System.out.println("Learned classification tree model:\n" + model); + } else if (algo.equals("Regression")) { + // Train a GradientBoosting model for classification. + final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double trainMSE = + predictionAndLabel.map(new Function, Double>() { + @Override public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Training Mean Squared Error: " + trainMSE); + System.out.println("Learned regression tree model:\n" + model); + } else { + usage(); + } + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 981bc4f0613a9..99df259b4e8e6 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -70,7 +70,7 @@ public static void main(String[] args) { // Create a input stream with the custom receiver on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') JavaReceiverInputDStream lines = ssc.receiverStream( - new JavaCustomReceiver(args[1], Integer.parseInt(args[2]))); + new JavaCustomReceiver(args[0], Integer.parseInt(args[1]))); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java index 45bcedebb4117..3e9f0f4b8f127 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java @@ -25,7 +25,7 @@ import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.api.java.StorageLevels; -import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; @@ -35,8 +35,9 @@ /** * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * * Usage: JavaNetworkWordCount - * and describe the TCP server that Spark Streaming would connect to receive data. + * and describe the TCP server that Spark Streaming would connect to receive data. * * To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` @@ -56,7 +57,7 @@ public static void main(String[] args) { // Create the context with a 1 second batch size SparkConf sparkConf = new SparkConf().setAppName("JavaNetworkWordCount"); - JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, new Duration(1000)); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); // Create a JavaReceiverInputDStream on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java new file mode 100644 index 0000000000000..bceda97f058ea --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.regex.Pattern; + +import scala.Tuple2; +import com.google.common.collect.Lists; +import com.google.common.io.Files; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.Time; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaStreamingContextFactory; + +/** + * Counts words in text encoded with UTF8 received from the network every second. + * + * Usage: JavaRecoverableNetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive + * data. directory to HDFS-compatible file system which checkpoint data + * file to which the word counts will be appended + * + * and must be absolute paths + * + * To run this on your local machine, you need to first run a Netcat server + * + * `$ nc -lk 9999` + * + * and run the example as + * + * `$ ./bin/run-example org.apache.spark.examples.streaming.JavaRecoverableNetworkWordCount \ + * localhost 9999 ~/checkpoint/ ~/out` + * + * If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create + * a new StreamingContext (will print "Creating new context" to the console). Otherwise, if + * checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from + * the checkpoint data. + * + * Refer to the online documentation for more details. + */ +public final class JavaRecoverableNetworkWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + private static JavaStreamingContext createContext(String ip, + int port, + String checkpointDirectory, + String outputPath) { + + // If you do not see this printed, that means the StreamingContext has been loaded + // from the new checkpoint + System.out.println("Creating new context"); + final File outputFile = new File(outputPath); + if (outputFile.exists()) { + outputFile.delete(); + } + SparkConf sparkConf = new SparkConf().setAppName("JavaRecoverableNetworkWordCount"); + // Create the context with a 1 second batch size + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); + ssc.checkpoint(checkpointDirectory); + + // Create a socket stream on target ip:port and count the + // words in input stream of \n delimited text (eg. generated by 'nc') + JavaReceiverInputDStream lines = ssc.socketTextStream(ip, port); + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }).reduceByKey(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }); + + wordCounts.foreachRDD(new Function2, Time, Void>() { + @Override + public Void call(JavaPairRDD rdd, Time time) throws IOException { + String counts = "Counts at time " + time + " " + rdd.collect(); + System.out.println(counts); + System.out.println("Appending to " + outputFile.getAbsolutePath()); + Files.append(counts + "\n", outputFile, Charset.defaultCharset()); + return null; + } + }); + + return ssc; + } + + public static void main(String[] args) { + if (args.length != 4) { + System.err.println("You arguments were " + Arrays.asList(args)); + System.err.println( + "Usage: JavaRecoverableNetworkWordCount \n" + + " . and describe the TCP server that Spark\n" + + " Streaming would connect to receive data. directory to\n" + + " HDFS-compatible file system which checkpoint data file to which\n" + + " the word counts will be appended\n" + + "\n" + + "In local mode, should be 'local[n]' with n > 1\n" + + "Both and must be absolute paths"); + System.exit(1); + } + + final String ip = args[0]; + final int port = Integer.parseInt(args[1]); + final String checkpointDirectory = args[2]; + final String outputPath = args[3]; + JavaStreamingContextFactory factory = new JavaStreamingContextFactory() { + @Override + public JavaStreamingContext create() { + return createContext(ip, port, checkpointDirectory, outputPath); + } + }; + JavaStreamingContext ssc = JavaStreamingContext.getOrCreate(checkpointDirectory, factory); + ssc.start(); + ssc.awaitTermination(); + } +} diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py new file mode 100644 index 0000000000000..540dae785f6ea --- /dev/null +++ b/examples/src/main/python/mllib/dataset_example.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example of how to use SchemaRDD as a dataset for ML. Run with:: + bin/spark-submit examples/src/main/python/mllib/dataset_example.py +""" + +import os +import sys +import tempfile +import shutil + +from pyspark import SparkContext +from pyspark.sql import SQLContext +from pyspark.mllib.util import MLUtils +from pyspark.mllib.stat import Statistics + + +def summarize(dataset): + print "schema: %s" % dataset.schema().json() + labels = dataset.map(lambda r: r.label) + print "label average: %f" % labels.mean() + features = dataset.map(lambda r: r.features) + summary = Statistics.colStats(features) + print "features average: %r" % summary.mean() + +if __name__ == "__main__": + if len(sys.argv) > 2: + print >> sys.stderr, "Usage: dataset_example.py " + exit(-1) + sc = SparkContext(appName="DatasetExample") + sqlCtx = SQLContext(sc) + if len(sys.argv) == 2: + input = sys.argv[1] + else: + input = "data/mllib/sample_libsvm_data.txt" + points = MLUtils.loadLibSVMFile(sc, input) + dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + summarize(dataset0) + tempdir = tempfile.NamedTemporaryFile(delete=False).name + os.unlink(tempdir) + print "Save dataset as a Parquet file to %s." % tempdir + dataset0.saveAsParquetFile(tempdir) + print "Load it back and summarize it again." + dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + summarize(dataset1) + shutil.rmtree(tempdir) diff --git a/examples/src/main/python/mllib/word2vec.py b/examples/src/main/python/mllib/word2vec.py new file mode 100644 index 0000000000000..99fef4276a369 --- /dev/null +++ b/examples/src/main/python/mllib/word2vec.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This example uses text8 file from http://mattmahoney.net/dc/text8.zip +# The file was downloadded, unziped and split into multiple lines using +# +# wget http://mattmahoney.net/dc/text8.zip +# unzip text8.zip +# grep -o -E '\w+(\W+\w+){0,15}' text8 > text8_lines +# This was done so that the example can be run in local mode + + +import sys + +from pyspark import SparkContext +from pyspark.mllib.feature import Word2Vec + +USAGE = ("bin/spark-submit --driver-memory 4g " + "examples/src/main/python/mllib/word2vec.py text8_lines") + +if __name__ == "__main__": + if len(sys.argv) < 2: + print USAGE + sys.exit("Argument for file not provided") + file_path = sys.argv[1] + sc = SparkContext(appName='Word2Vec') + inp = sc.textFile(file_path).map(lambda row: row.split(" ")) + + word2vec = Word2Vec() + model = word2vec.fit(inp) + + synonyms = model.findSynonyms('china', 40) + + for word, cosine_distance in synonyms: + print "{}: {}".format(word, cosine_distance) + sc.stop() diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index b539c4128cdcc..a5f25d78c1146 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +This is an example implementation of PageRank. For more conventional use, +Please refer to PageRank implementation provided by graphx +""" + import re import sys from operator import add @@ -40,6 +45,9 @@ def parseNeighbors(urls): print >> sys.stderr, "Usage: pagerank " exit(-1) + print >> sys.stderr, """WARN: This is a naive implementation of PageRank and is + given as an example! Please refer to PageRank implementation provided by graphx""" + # Initialize the spark context. sc = SparkContext(appName="PythonPageRank") diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index 931faac5463c4..ac2ea35bbd0e0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -25,7 +25,8 @@ import breeze.linalg.{Vector, DenseVector} * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object LocalFileLR { val D = 10 // Numer of dimensions @@ -41,7 +42,8 @@ object LocalFileLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 2d75b9d2590f8..92a683ad57ea1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -25,7 +25,8 @@ import breeze.linalg.{Vector, DenseVector} * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object LocalLR { val N = 10000 // Number of data points @@ -48,7 +49,8 @@ object LocalLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 3258510894372..9099c2fcc90b3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -32,7 +32,8 @@ import org.apache.spark.scheduler.InputFormatInfo * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object SparkHdfsLR { val D = 10 // Numer of dimensions @@ -54,7 +55,8 @@ object SparkHdfsLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index fc23308fc4adf..257a7d29f922a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -30,7 +30,8 @@ import org.apache.spark._ * Usage: SparkLR [slices] * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object SparkLR { val N = 10000 // Number of data points @@ -53,7 +54,8 @@ object SparkLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 4c7e006da0618..8d092b6506d33 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -28,13 +28,28 @@ import org.apache.spark.{SparkConf, SparkContext} * URL neighbor URL * ... * where URL and their neighbors are separated by space(s). + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.graphx.lib.PageRank */ object SparkPageRank { + + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of PageRank and is given as an example! + |Please use the PageRank implementation found in org.apache.spark.graphx.lib.PageRank + |for more conventional use. + """.stripMargin) + } + def main(args: Array[String]) { if (args.length < 1) { System.err.println("Usage: SparkPageRank ") System.exit(1) } + + showWarning() + val sparkConf = new SparkConf().setAppName("PageRank") val iters = if (args.length > 0) args(1).toInt else 10 val ctx = new SparkContext(sparkConf) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 96d13612e46dd..4393b99e636b6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -32,11 +32,24 @@ import org.apache.spark.storage.StorageLevel /** * Logistic regression based classification. * This example uses Tachyon to persist rdds during computation. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object SparkTachyonHdfsLR { val D = 10 // Numer of dimensions val rand = new Random(42) + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of Logistic Regression and is given as an example! + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS + |for more conventional use. + """.stripMargin) + } + case class DataPoint(x: Vector[Double], y: Double) def parsePoint(line: String): DataPoint = { @@ -51,6 +64,9 @@ object SparkTachyonHdfsLR { } def main(args: Array[String]) { + + showWarning() + val inputPath = args(0) val sparkConf = new SparkConf().setAppName("SparkTachyonHdfsLR") val conf = new Configuration() diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index d70d93608a57c..828cffb01ca1e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -77,7 +77,7 @@ object Analytics extends Logging { val sc = new SparkContext(conf.setAppName("PageRank(" + fname + ")")) val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname, - minEdgePartitions = numEPart, + numEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel).cache() val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) @@ -110,7 +110,7 @@ object Analytics extends Logging { val sc = new SparkContext(conf.setAppName("ConnectedComponents(" + fname + ")")) val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname, - minEdgePartitions = numEPart, + numEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel).cache() val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) @@ -131,7 +131,7 @@ object Analytics extends Logging { val sc = new SparkContext(conf.setAppName("TriangleCount(" + fname + ")")) val graph = GraphLoader.edgeListFile(sc, fname, canonicalOrientation = true, - minEdgePartitions = numEPart, + numEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel) // TriangleCount requires the graph to be partitioned diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala new file mode 100644 index 0000000000000..f8d83f4ec7327 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import java.io.File + +import com.google.common.io.Files +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} + +/** + * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object DatasetExample { + + case class Params( + input: String = "data/mllib/sample_libsvm_data.txt", + dataFormat: String = "libsvm") extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DatasetExample") { + head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + opt[String]("input") + .text(s"input path to dataset") + .action((x, c) => c.copy(input = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(input = x)) + checkConfig { params => + success + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DatasetExample with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ // for implicit conversions + + // Load input data + val origData: RDD[LabeledPoint] = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.input) + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) + } + println(s"Loaded ${origData.count()} instances from file: ${params.input}") + + // Convert input data to SchemaRDD explicitly. + val schemaRDD: SchemaRDD = origData + println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") + println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + + // Select columns, using implicit conversion to SchemaRDD. + val labelsSchemaRDD: SchemaRDD = origData.select('label) + val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + val numLabels = labels.count() + val meanLabel = labels.fold(0.0)(_ + _) / numLabels + println(s"Selected label column with average value $meanLabel") + + val featuresSchemaRDD: SchemaRDD = origData.select('features) + val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + + val tmpDir = Files.createTempDir() + tmpDir.deleteOnExit() + val outputDir = new File(tmpDir, "dataset").toString + println(s"Saving to $outputDir as Parquet file.") + schemaRDD.saveAsParquetFile(outputDir) + + println(s"Loading Parquet file with UDT from $outputDir.") + val newDataset = sqlContext.parquetFile(outputDir) + + println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") + val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") + + sc.stop() + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 0890e6263e165..63f02cf7b98b9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -62,7 +62,10 @@ object DecisionTreeRunner { minInfoGain: Double = 0.0, numTrees: Int = 1, featureSubsetStrategy: String = "auto", - fracTest: Double = 0.2) extends AbstractParams[Params] + fracTest: Double = 0.2, + useNodeIdCache: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -102,6 +105,21 @@ object DecisionTreeRunner { .text(s"fraction of data to hold out for testing. If given option testInput, " + s"this option is ignored. default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("useNodeIdCache") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.useNodeIdCache}") + .action((x, c) => c.copy(useNodeIdCache = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + }}") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) opt[String]("testInput") .text(s"input path to test dataset. If given, option fracTest is ignored." + s" default: ${defaultParams.testInput}") @@ -136,20 +154,30 @@ object DecisionTreeRunner { } } - def run(params: Params) { - - val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") - val sc = new SparkContext(conf) - - println(s"DecisionTreeRunner with parameters:\n$params") - + /** + * Load training and test data from files. + * @param input Path to input dataset. + * @param dataFormat "libsvm" or "dense" + * @param testInput Path to test dataset. + * @param algo Classification or Regression + * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. + * @return (training dataset, test dataset, number of classes), + * where the number of classes is inferred from data (and set to 0 for Regression) + */ + private[mllib] def loadDatasets( + sc: SparkContext, + input: String, + dataFormat: String, + testInput: String, + algo: Algo, + fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = { // Load training data and cache it. - val origExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache() + val origExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, input).cache() + case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache() } // For classification, re-index classes if needed. - val (examples, classIndexMap, numClasses) = params.algo match { + val (examples, classIndexMap, numClasses) = algo match { case Classification => { // classCounts: class --> # examples in class val classCounts = origExamples.map(_.label).countByValue() @@ -187,14 +215,14 @@ object DecisionTreeRunner { } // Create training, test sets. - val splits = if (params.testInput != "") { + val splits = if (testInput != "") { // Load testInput. val numFeatures = examples.take(1)(0).features.size - val origTestExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures) + val origTestExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, testInput) + case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures) } - params.algo match { + algo match { case Classification => { // classCounts: class --> # examples in class val testExamples = { @@ -211,17 +239,31 @@ object DecisionTreeRunner { } } else { // Split input into training, test. - examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) + examples.randomSplit(Array(1.0 - fracTest, fracTest)) } val training = splits(0).cache() val test = splits(1).cache() + val numTraining = training.count() val numTest = test.count() - println(s"numTraining = $numTraining, numTest = $numTest.") examples.unpersist(blocking = false) + (training, test, numClasses) + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") + val sc = new SparkContext(conf) + + println(s"DecisionTreeRunner with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat, + params.testInput, params.algo, params.fracTest) + val impurityCalculator = params.impurity match { case Gini => impurity.Gini case Entropy => impurity.Entropy @@ -236,7 +278,10 @@ object DecisionTreeRunner { maxBins = params.maxBins, numClassesForClassification = numClasses, minInstancesPerNode = params.minInstancesPerNode, - minInfoGain = params.minInfoGain) + minInfoGain = params.minInfoGain, + useNodeIdCache = params.useNodeIdCache, + checkpointDir = params.checkpointDir, + checkpointInterval = params.checkpointInterval) if (params.numTrees == 1) { val startTime = System.nanoTime() val model = DecisionTree.train(training, strategy) @@ -317,7 +362,9 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. */ - private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = { + private[mllib] def meanSquaredError( + tree: WeightedEnsembleModel, + data: RDD[LabeledPoint]): Double = { data.map { y => val err = tree.predict(y.features) - y.label err * err diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala new file mode 100644 index 0000000000000..9b6db01448be0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.tree.GradientBoosting +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} +import org.apache.spark.util.Utils + +/** + * An example runner for Gradient Boosting using decision trees as weak learners. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify categoricalFeaturesInfo. + */ +object GradientBoostedTrees { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "Classification", + maxDepth: Int = 5, + numIterations: Int = 10, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("GradientBoostedTrees") { + head("GradientBoostedTrees: an example decision tree app.") + opt[String]("algo") + .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("numIterations") + .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}") + .action((x, c) => c.copy(numIterations = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params") + val sc = new SparkContext(conf) + + println(s"GradientBoostedTrees with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest) + + val boostingStrategy = BoostingStrategy.defaultParams(params.algo) + boostingStrategy.numClassesForClassification = numClasses + boostingStrategy.numIterations = params.numIterations + boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth + + val randomSeed = Utils.random.nextInt() + if (params.algo == "Classification") { + val startTime = System.nanoTime() + val model = GradientBoosting.trainClassifier(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } + val trainAccuracy = + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) + .precision + println(s"Train accuracy = $trainAccuracy") + val testAccuracy = + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + println(s"Test accuracy = $testAccuracy") + } else if (params.algo == "Regression") { + val startTime = System.nanoTime() + val model = GradientBoosting.trainRegressor(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } + val trainMSE = DecisionTreeRunner.meanSquaredError(model, training) + println(s"Train mean squared error = $trainMSE") + val testMSE = DecisionTreeRunner.meanSquaredError(model, test) + println(s"Test mean squared error = $testMSE") + } + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 8796c28db8a66..91a0a860d6c71 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -106,9 +106,11 @@ object MovieLensALS { Logger.getRootLogger.setLevel(Level.WARN) + val implicitPrefs = params.implicitPrefs + val ratings = sc.textFile(params.input).map { line => val fields = line.split("::") - if (params.implicitPrefs) { + if (implicitPrefs) { /* * MovieLens ratings are on a scale of 1-5: * 5: Must see diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala new file mode 100644 index 0000000000000..33e5760aed997 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.clustering.StreamingKMeans +import org.apache.spark.SparkConf +import org.apache.spark.streaming.{Seconds, StreamingContext} + +/** + * Estimate clusters on one stream of data and make predictions + * on another stream, where the data streams arrive as text files + * into two different directories. + * + * The rows of the training text files must be vector data in the form + * `[x1,x2,x3,...,xn]` + * Where n is the number of dimensions. + * + * The rows of the test text files must be labeled data in the form + * `(y,[x1,x2,x3,...,xn])` + * Where y is some identifier. n must be the same for train and test. + * + * Usage: StreamingKmeans + * + * To run on your local machine using the two directories `trainingDir` and `testDir`, + * with updates every 5 seconds, 2 dimensions per data point, and 3 clusters, call: + * $ bin/run-example \ + * org.apache.spark.examples.mllib.StreamingKMeans trainingDir testDir 5 3 2 + * + * As you add text files to `trainingDir` the clusters will continuously update. + * Anytime you add text files to `testDir`, you'll see predicted labels using the current model. + * + */ +object StreamingKMeans { + + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println( + "Usage: StreamingKMeans " + + " ") + System.exit(1) + } + + val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression") + val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) + + val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse) + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) + + val model = new StreamingKMeans() + .setK(args(3).toInt) + .setDecayFactor(1.0) + .setRandomCenters(args(4).toInt, 0.0) + + model.trainOn(trainingData) + model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 6af3a0f33efc2..19427e629f76d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -31,15 +31,13 @@ import org.apache.spark.util.IntParam /** * Counts words in text encoded with UTF8 received from the network every second. * - * Usage: NetworkWordCount + * Usage: RecoverableNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. directory to HDFS-compatible file system which checkpoint data * file to which the word counts will be appended * - * In local mode, should be 'local[n]' with n > 1 * and must be absolute paths * - * * To run this on your local machine, you need to first run a Netcat server * * `$ nc -lk 9999` @@ -54,22 +52,11 @@ import org.apache.spark.util.IntParam * checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from * the checkpoint data. * - * To run this example in a local standalone cluster with automatic driver recovery, - * - * `$ bin/spark-class org.apache.spark.deploy.Client -s launch \ - * \ - * org.apache.spark.examples.streaming.RecoverableNetworkWordCount \ - * localhost 9999 ~/checkpoint ~/out` - * - * would typically be - * /examples/target/scala-XX/spark-examples....jar - * * Refer to the online documentation for more details. */ - object RecoverableNetworkWordCount { - def createContext(ip: String, port: Int, outputPath: String) = { + def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) = { // If you do not see this printed, that means the StreamingContext has been loaded // from the new checkpoint @@ -79,6 +66,7 @@ object RecoverableNetworkWordCount { val sparkConf = new SparkConf().setAppName("RecoverableNetworkWordCount") // Create the context with a 1 second batch size val ssc = new StreamingContext(sparkConf, Seconds(1)) + ssc.checkpoint(checkpointDirectory) // Create a socket stream on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') @@ -114,7 +102,7 @@ object RecoverableNetworkWordCount { val Array(ip, IntParam(port), checkpointDirectory, outputPath) = args val ssc = StreamingContext.getOrCreate(checkpointDirectory, () => { - createContext(ip, port, outputPath) + createContext(ip, port, outputPath, checkpointDirectory) }) ssc.start() ssc.awaitTermination() diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..4411d6e20c52a --- /dev/null +++ b/external/flume-sink/src/test/resources/log4j.properties @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file streaming/target/unit-tests.log +log4j.rootCategory=INFO, file +# log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN + diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index a2b2cc6149d95..650b2fbe1c142 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -159,6 +159,7 @@ class SparkSinkSuite extends FunSuite { channelContext.put("transactionCapacity", 1000.toString) channelContext.put("keep-alive", 0.toString) channelContext.putAll(overrides) + channel.setName(scala.util.Random.nextString(10)) channel.configure(channelContext) val sink = new SparkSink() diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 32a19787a28e1..475026e8eb140 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -145,11 +145,16 @@ class FlumePollingStreamSuite extends TestSuiteBase { outputStream.register() ssc.start() - writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) - assertChannelIsEmpty(channel) - assertChannelIsEmpty(channel2) - sink.stop() - channel.stop() + try { + writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) + assertChannelIsEmpty(channel) + assertChannelIsEmpty(channel2) + } finally { + sink.stop() + sink2.stop() + channel.stop() + channel2.stop() + } } def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index e20e2c8f26991..28ac5929df44a 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -26,8 +26,6 @@ import java.util.concurrent.Executors import kafka.consumer._ import kafka.serializer.Decoder import kafka.utils.VerifiableProperties -import kafka.utils.ZKStringSerializer -import org.I0Itec.zkclient._ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel @@ -97,12 +95,6 @@ class KafkaReceiver[ consumerConnector = Consumer.create(consumerConfig) logInfo("Connected to " + zkConnect) - // When auto.offset.reset is defined, it is our responsibility to try and whack the - // consumer group zk node. - if (kafkaParams.contains("auto.offset.reset")) { - tryZookeeperConsumerGroupCleanup(zkConnect, kafkaParams("group.id")) - } - val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) .newInstance(consumerConfig.props) .asInstanceOf[Decoder[K]] @@ -139,26 +131,4 @@ class KafkaReceiver[ } } } - - // It is our responsibility to delete the consumer group when specifying auto.offset.reset. This - // is because Kafka 0.7.2 only honors this param when the group is not in zookeeper. - // - // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied - // from Kafka's ConsoleConsumer. See code related to 'auto.offset.reset' when it is set to - // 'smallest'/'largest': - // scalastyle:off - // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala - // scalastyle:on - private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { - val dir = "/consumers/" + groupId - logInfo("Cleaning up temporary Zookeeper data under " + dir + ".") - val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer) - try { - zk.deleteRecursive(dir) - } catch { - case e: Throwable => logWarning("Error cleaning up temporary Zookeeper data", e) - } finally { - zk.close() - } - } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 48668f763e41e..ec812e1ef3b04 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -17,19 +17,18 @@ package org.apache.spark.streaming.kafka -import scala.reflect.ClassTag -import scala.collection.JavaConversions._ - import java.lang.{Integer => JInt} import java.util.{Map => JMap} +import scala.reflect.ClassTag +import scala.collection.JavaConversions._ + import kafka.serializer.{Decoder, StringDecoder} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext, JavaPairDStream} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} - +import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.ReceiverInputDStream object KafkaUtils { /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index 5bcb96b136ed7..5267560b3e5ce 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -82,12 +82,17 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( this } - /** Persists the vertex partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */ + /** Persists the edge partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */ override def cache(): this.type = { partitionsRDD.persist(targetStorageLevel) this } + /** The number of edges in the RDD. */ + override def count(): Long = { + partitionsRDD.map(_._2.size.toLong).reduce(_ + _) + } + private[graphx] def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag]( f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDD[ED2, VD2] = { this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter => diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala index f4c79365b16da..4933aecba1286 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala @@ -48,7 +48,8 @@ object GraphLoader extends Logging { * @param path the path to the file (e.g., /home/data/file or hdfs://file) * @param canonicalOrientation whether to orient edges in the positive * direction - * @param minEdgePartitions the number of partitions for the edge RDD + * @param numEdgePartitions the number of partitions for the edge RDD + * Setting this value to -1 will use the default parallelism. * @param edgeStorageLevel the desired storage level for the edge partitions * @param vertexStorageLevel the desired storage level for the vertex partitions */ @@ -56,7 +57,7 @@ object GraphLoader extends Logging { sc: SparkContext, path: String, canonicalOrientation: Boolean = false, - minEdgePartitions: Int = 1, + numEdgePartitions: Int = -1, edgeStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY, vertexStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) : Graph[Int, Int] = @@ -64,7 +65,12 @@ object GraphLoader extends Logging { val startTime = System.currentTimeMillis // Parse the edge data table directly into edge partitions - val lines = sc.textFile(path, minEdgePartitions).coalesce(minEdgePartitions) + val lines = + if (numEdgePartitions > 0) { + sc.textFile(path, numEdgePartitions).coalesce(numEdgePartitions) + } else { + sc.textFile(path) + } val edges = lines.mapPartitionsWithIndex { (pid, iter) => val builder = new EdgePartitionBuilder[Int, Int] iter.foreach { line => diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 2c8b245955d12..12216d9d33d66 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -27,8 +27,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx.impl.RoutingTablePartition import org.apache.spark.graphx.impl.ShippableVertexPartition import org.apache.spark.graphx.impl.VertexAttributeBlock -import org.apache.spark.graphx.impl.RoutingTableMessageRDDFunctions._ -import org.apache.spark.graphx.impl.VertexRDDFunctions._ /** * Extends `RDD[(VertexId, VD)]` by ensuring that there is only one entry for each vertex and by @@ -233,7 +231,7 @@ class VertexRDD[@specialized VD: ClassTag]( case _ => this.withPartitionsRDD[VD3]( partitionsRDD.zipPartitions( - other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) { + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { (partIter, msgs) => partIter.map(_.leftJoin(msgs)(f)) } ) @@ -277,7 +275,7 @@ class VertexRDD[@specialized VD: ClassTag]( case _ => this.withPartitionsRDD( partitionsRDD.zipPartitions( - other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) { + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { (partIter, msgs) => partIter.map(_.innerJoin(msgs)(f)) } ) @@ -297,7 +295,7 @@ class VertexRDD[@specialized VD: ClassTag]( */ def aggregateUsingIndex[VD2: ClassTag]( messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = { - val shuffled = messages.copartitionWithVertices(this.partitioner.get) + val shuffled = messages.partitionBy(this.partitioner.get) val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) => thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc)) } @@ -371,7 +369,7 @@ object VertexRDD { def apply[VD: ClassTag](vertices: RDD[(VertexId, VD)]): VertexRDD[VD] = { val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { case Some(p) => vertices - case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size)) + case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size)) } val vertexPartitions = vPartitioned.mapPartitions( iter => Iterator(ShippableVertexPartition(iter)), @@ -412,7 +410,7 @@ object VertexRDD { ): VertexRDD[VD] = { val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { case Some(p) => vertices - case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size)) + case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size)) } val routingTables = createRoutingTables(edges, vPartitioned.partitioner.get) val vertexPartitions = vPartitioned.zipPartitions(routingTables, preservesPartitioning = true) { @@ -454,7 +452,7 @@ object VertexRDD { .setName("VertexRDD.createRoutingTables - vid2pid (aggregation)") val numEdgePartitions = edges.partitions.size - vid2pid.copartitionWithVertices(vertexPartitioner).mapPartitions( + vid2pid.partitionBy(vertexPartitioner).mapPartitions( iter => Iterator(RoutingTablePartition.fromMsgs(numEdgePartitions, iter)), preservesPartitioning = true) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 4520beb991515..2b6137be25547 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -45,8 +45,8 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and // adding them to the index if (edgeArray.length > 0) { - index.update(srcIds(0), 0) - var currSrcId: VertexId = srcIds(0) + index.update(edgeArray(0).srcId, 0) + var currSrcId: VertexId = edgeArray(0).srcId var i = 0 while (i < edgeArray.size) { srcIds(i) = edgeArray(i).srcId diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala deleted file mode 100644 index 714f3b81c9dad..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx.impl - -import scala.language.implicitConversions -import scala.reflect.{classTag, ClassTag} - -import org.apache.spark.Partitioner -import org.apache.spark.graphx.{PartitionID, VertexId} -import org.apache.spark.rdd.{ShuffledRDD, RDD} - - -private[graphx] -class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) { - def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = { - val rdd = new ShuffledRDD[VertexId, VD, VD](self, partitioner) - - // Set a custom serializer if the data is of int or double type. - if (classTag[VD] == ClassTag.Int) { - rdd.setSerializer(new IntAggMsgSerializer) - } else if (classTag[VD] == ClassTag.Long) { - rdd.setSerializer(new LongAggMsgSerializer) - } else if (classTag[VD] == ClassTag.Double) { - rdd.setSerializer(new DoubleAggMsgSerializer) - } - rdd - } -} - -private[graphx] -object VertexRDDFunctions { - implicit def rdd2VertexRDDFunctions[VD: ClassTag](rdd: RDD[(VertexId, VD)]) = { - new VertexRDDFunctions(rdd) - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index b27485953f719..7a7fa91aadfe1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -29,24 +29,6 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage -private[graphx] -class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { - /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */ - def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { - new ShuffledRDD[VertexId, Int, Int]( - self, partitioner).setSerializer(new RoutingTableMessageSerializer) - } -} - -private[graphx] -object RoutingTableMessageRDDFunctions { - import scala.language.implicitConversions - - implicit def rdd2RoutingTableMessageRDDFunctions(rdd: RDD[RoutingTableMessage]) = { - new RoutingTableMessageRDDFunctions(rdd) - } -} - private[graphx] object RoutingTablePartition { /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala deleted file mode 100644 index 3909efcdfc993..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala +++ /dev/null @@ -1,369 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx.impl - -import scala.language.existentials - -import java.io.{EOFException, InputStream, OutputStream} -import java.nio.ByteBuffer - -import scala.reflect.ClassTag - -import org.apache.spark.serializer._ - -import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage - -private[graphx] -class RoutingTableMessageSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream): SerializationStream = - new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T): SerializationStream = { - val msg = t.asInstanceOf[RoutingTableMessage] - writeVarLong(msg._1, optimizePositive = false) - writeInt(msg._2) - this - } - } - - override def deserializeStream(s: InputStream): DeserializationStream = - new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readInt() - (a, b).asInstanceOf[T] - } - } - } -} - -private[graphx] -class VertexIdMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, _)] - writeVarLong(msg._1, optimizePositive = false) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - (readVarLong(optimizePositive = false), null).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for AggregationMessage[Int]. */ -private[graphx] -class IntAggMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, Int)] - writeVarLong(msg._1, optimizePositive = false) - writeUnsignedVarInt(msg._2) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readUnsignedVarInt() - (a, b).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for AggregationMessage[Long]. */ -private[graphx] -class LongAggMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, Long)] - writeVarLong(msg._1, optimizePositive = false) - writeVarLong(msg._2, optimizePositive = true) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readVarLong(optimizePositive = true) - (a, b).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for AggregationMessage[Double]. */ -private[graphx] -class DoubleAggMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, Double)] - writeVarLong(msg._1, optimizePositive = false) - writeDouble(msg._2) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readDouble() - (a, b).asInstanceOf[T] - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Helper classes to shorten the implementation of those special serializers. -//////////////////////////////////////////////////////////////////////////////// - -private[graphx] -abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream { - // The implementation should override this one. - def writeObject[T: ClassTag](t: T): SerializationStream - - def writeInt(v: Int) { - s.write(v >> 24) - s.write(v >> 16) - s.write(v >> 8) - s.write(v) - } - - def writeUnsignedVarInt(value: Int) { - if ((value >>> 7) == 0) { - s.write(value.toInt) - } else if ((value >>> 14) == 0) { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7) - } else if ((value >>> 21) == 0) { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7 | 0x80) - s.write(value >>> 14) - } else if ((value >>> 28) == 0) { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7 | 0x80) - s.write(value >>> 14 | 0x80) - s.write(value >>> 21) - } else { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7 | 0x80) - s.write(value >>> 14 | 0x80) - s.write(value >>> 21 | 0x80) - s.write(value >>> 28) - } - } - - def writeVarLong(value: Long, optimizePositive: Boolean) { - val v = if (!optimizePositive) (value << 1) ^ (value >> 63) else value - if ((v >>> 7) == 0) { - s.write(v.toInt) - } else if ((v >>> 14) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7).toInt) - } else if ((v >>> 21) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14).toInt) - } else if ((v >>> 28) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21).toInt) - } else if ((v >>> 35) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28).toInt) - } else if ((v >>> 42) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35).toInt) - } else if ((v >>> 49) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35 | 0x80).toInt) - s.write((v >>> 42).toInt) - } else if ((v >>> 56) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35 | 0x80).toInt) - s.write((v >>> 42 | 0x80).toInt) - s.write((v >>> 49).toInt) - } else { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35 | 0x80).toInt) - s.write((v >>> 42 | 0x80).toInt) - s.write((v >>> 49 | 0x80).toInt) - s.write((v >>> 56).toInt) - } - } - - def writeLong(v: Long) { - s.write((v >>> 56).toInt) - s.write((v >>> 48).toInt) - s.write((v >>> 40).toInt) - s.write((v >>> 32).toInt) - s.write((v >>> 24).toInt) - s.write((v >>> 16).toInt) - s.write((v >>> 8).toInt) - s.write(v.toInt) - } - - def writeDouble(v: Double): Unit = writeLong(java.lang.Double.doubleToLongBits(v)) - - override def flush(): Unit = s.flush() - - override def close(): Unit = s.close() -} - -private[graphx] -abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream { - // The implementation should override this one. - def readObject[T: ClassTag](): T - - def readInt(): Int = { - val first = s.read() - if (first < 0) throw new EOFException - (first & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF) - } - - def readUnsignedVarInt(): Int = { - var value: Int = 0 - var i: Int = 0 - def readOrThrow(): Int = { - val in = s.read() - if (in < 0) throw new EOFException - in & 0xFF - } - var b: Int = readOrThrow() - while ((b & 0x80) != 0) { - value |= (b & 0x7F) << i - i += 7 - if (i > 35) throw new IllegalArgumentException("Variable length quantity is too long") - b = readOrThrow() - } - value | (b << i) - } - - def readVarLong(optimizePositive: Boolean): Long = { - def readOrThrow(): Int = { - val in = s.read() - if (in < 0) throw new EOFException - in & 0xFF - } - var b = readOrThrow() - var ret: Long = b & 0x7F - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F) << 7 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F) << 14 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F) << 21 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 28 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 35 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 42 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 49 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= b.toLong << 56 - } - } - } - } - } - } - } - } - if (!optimizePositive) (ret >>> 1) ^ -(ret & 1) else ret - } - - def readLong(): Long = { - val first = s.read() - if (first < 0) throw new EOFException() - (first.toLong << 56) | - (s.read() & 0xFF).toLong << 48 | - (s.read() & 0xFF).toLong << 40 | - (s.read() & 0xFF).toLong << 32 | - (s.read() & 0xFF).toLong << 24 | - (s.read() & 0xFF) << 16 | - (s.read() & 0xFF) << 8 | - (s.read() & 0xFF) - } - - def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong()) - - override def close(): Unit = s.close() -} - -private[graphx] sealed trait ShuffleSerializerInstance extends SerializerInstance { - - override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException - - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = - throw new UnsupportedOperationException - - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = - throw new UnsupportedOperationException - - // The implementation should override the following two. - override def serializeStream(s: OutputStream): SerializationStream - override def deserializeStream(s: InputStream): DeserializationStream -} diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala deleted file mode 100644 index 864cb1fdf0022..0000000000000 --- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx - -import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream} - -import scala.util.Random -import scala.reflect.ClassTag - -import org.scalatest.FunSuite - -import org.apache.spark._ -import org.apache.spark.graphx.impl._ -import org.apache.spark.serializer.SerializationStream - - -class SerializerSuite extends FunSuite with LocalSparkContext { - - test("IntAggMsgSerializer") { - val outMsg = (4: VertexId, 5) - val bout = new ByteArrayOutputStream - val outStrm = new IntAggMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new IntAggMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: (VertexId, Int) = inStrm.readObject() - val inMsg2: (VertexId, Int) = inStrm.readObject() - assert(outMsg === inMsg1) - assert(outMsg === inMsg2) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("LongAggMsgSerializer") { - val outMsg = (4: VertexId, 1L << 32) - val bout = new ByteArrayOutputStream - val outStrm = new LongAggMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new LongAggMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: (VertexId, Long) = inStrm.readObject() - val inMsg2: (VertexId, Long) = inStrm.readObject() - assert(outMsg === inMsg1) - assert(outMsg === inMsg2) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("DoubleAggMsgSerializer") { - val outMsg = (4: VertexId, 5.0) - val bout = new ByteArrayOutputStream - val outStrm = new DoubleAggMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new DoubleAggMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: (VertexId, Double) = inStrm.readObject() - val inMsg2: (VertexId, Double) = inStrm.readObject() - assert(outMsg === inMsg1) - assert(outMsg === inMsg2) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("variable long encoding") { - def testVarLongEncoding(v: Long, optimizePositive: Boolean) { - val bout = new ByteArrayOutputStream - val stream = new ShuffleSerializationStream(bout) { - def writeObject[T: ClassTag](t: T): SerializationStream = { - writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive) - this - } - } - stream.writeObject(v) - - val bin = new ByteArrayInputStream(bout.toByteArray) - val dstream = new ShuffleDeserializationStream(bin) { - def readObject[T: ClassTag](): T = { - readVarLong(optimizePositive).asInstanceOf[T] - } - } - val read = dstream.readObject[Long]() - assert(read === v) - } - - // Test all variable encoding code path (each branch uses 7 bits, i.e. 1L << 7 difference) - val d = Random.nextLong() % 128 - Seq[Long](0, 1L << 0 + d, 1L << 7 + d, 1L << 14 + d, 1L << 21 + d, 1L << 28 + d, 1L << 35 + d, - 1L << 42 + d, 1L << 49 + d, 1L << 56 + d, 1L << 63 + d).foreach { number => - testVarLongEncoding(number, optimizePositive = false) - testVarLongEncoding(number, optimizePositive = true) - testVarLongEncoding(-number, optimizePositive = false) - testVarLongEncoding(-number, optimizePositive = true) - } - } -} diff --git a/mllib/pom.xml b/mllib/pom.xml index de062a4901596..87a7ddaba97f2 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -45,6 +45,11 @@ spark-streaming_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + org.eclipse.jetty jetty-server @@ -65,6 +70,10 @@ junit junit + + org.apache.commons + commons-math3 + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 485abe272326c..70d7138e3060f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.api.python import java.io.OutputStream -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.language.existentials @@ -43,6 +43,7 @@ import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames +import org.apache.spark.mllib.stat.test.ChiSqTestResult import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -72,15 +73,11 @@ class PythonMLLibAPI extends Serializable { private def trainRegressionModel( learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], - initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { - val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector] + initialWeights: Vector): JList[Object] = { // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. learner.disableUncachedWarning() val model = learner.run(data.rdd, initialWeights) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(SerDe.dumps(model.weights)) - ret.add(model.intercept: java.lang.Double) - ret + List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava } /** @@ -91,10 +88,10 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) lrAlg.optimizer @@ -113,7 +110,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( lrAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -125,7 +122,7 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeights: Vector): JList[Object] = { val lassoAlg = new LassoWithSGD() lassoAlg.optimizer .setNumIterations(numIterations) @@ -135,7 +132,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( lassoAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -147,7 +144,7 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeights: Vector): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() ridgeAlg.optimizer .setNumIterations(numIterations) @@ -157,7 +154,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( ridgeAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -169,9 +166,9 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) SVMAlg.optimizer @@ -190,7 +187,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( SVMAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -201,10 +198,10 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) LogRegAlg.optimizer @@ -223,7 +220,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( LogRegAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -231,13 +228,10 @@ class PythonMLLibAPI extends Serializable { */ def trainNaiveBayes( data: JavaRDD[LabeledPoint], - lambda: Double): java.util.List[java.lang.Object] = { + lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(Vectors.dense(model.labels)) - ret.add(Vectors.dense(model.pi)) - ret.add(model.theta) - ret + List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta). + map(_.asInstanceOf[Object]).asJava } /** @@ -259,6 +253,21 @@ class PythonMLLibAPI extends Serializable { return kMeansAlg.run(data.rdd) } + /** + * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python + */ + private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) + extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { + + def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = + predict(SerDe.asTupleRDD(userAndProducts.rdd)) + + def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + + def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + + } + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -266,12 +275,25 @@ class PythonMLLibAPI extends Serializable { * the Py4J documentation. */ def trainALSModel( - ratings: JavaRDD[Rating], + ratingsJRDD: JavaRDD[Rating], rank: Int, iterations: Int, lambda: Double, - blocks: Int): MatrixFactorizationModel = { - ALS.train(ratings.rdd, rank, iterations, lambda, blocks) + blocks: Int, + nonnegative: Boolean, + seed: java.lang.Long): MatrixFactorizationModel = { + + val als = new ALS() + .setRank(rank) + .setIterations(iterations) + .setLambda(lambda) + .setBlocks(blocks) + .setNonnegative(nonnegative) + + if (seed != null) als.setSeed(seed) + + val model = als.run(ratingsJRDD.rdd) + new MatrixFactorizationModelWrapper(model) } /** @@ -286,8 +308,23 @@ class PythonMLLibAPI extends Serializable { iterations: Int, lambda: Double, blocks: Int, - alpha: Double): MatrixFactorizationModel = { - ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha) + alpha: Double, + nonnegative: Boolean, + seed: java.lang.Long): MatrixFactorizationModel = { + + val als = new ALS() + .setImplicitPrefs(true) + .setRank(rank) + .setIterations(iterations) + .setLambda(lambda) + .setBlocks(blocks) + .setAlpha(alpha) + .setNonnegative(nonnegative) + + if (seed != null) als.setSeed(seed) + + val model = als.run(ratingsJRDD.rdd) + new MatrixFactorizationModelWrapper(model) } /** @@ -373,19 +410,16 @@ class PythonMLLibAPI extends Serializable { rdd.rdd.map(model.transform) } - def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = { + def findSynonyms(word: String, num: Int): JList[Object] = { val vec = transform(word) findSynonyms(vec, num) } - def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = { + def findSynonyms(vector: Vector, num: Int): JList[Object] = { val result = model.findSynonyms(vector, num) val similarity = Vectors.dense(result.map(_._2)) val words = result.map(_._1) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(words) - ret.add(similarity) - ret + List(words, similarity).map(_.asInstanceOf[Object]).asJava } } @@ -395,13 +429,13 @@ class PythonMLLibAPI extends Serializable { * Extra care needs to be taken in the Python code to ensure it gets freed on exit; * see the Py4J documentation. * @param data Training data - * @param categoricalFeaturesInfoJMap Categorical features info, as Java map + * @param categoricalFeaturesInfo Categorical features info, as Java map */ def trainDecisionTreeModel( data: JavaRDD[LabeledPoint], algoStr: String, numClasses: Int, - categoricalFeaturesInfoJMap: java.util.Map[Int, Int], + categoricalFeaturesInfo: JMap[Int, Int], impurityStr: String, maxDepth: Int, maxBins: Int, @@ -417,7 +451,7 @@ class PythonMLLibAPI extends Serializable { maxDepth = maxDepth, numClassesForClassification = numClasses, maxBins = maxBins, - categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap, + categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap, minInstancesPerNode = minInstancesPerNode, minInfoGain = minInfoGain) @@ -448,6 +482,31 @@ class PythonMLLibAPI extends Serializable { Statistics.corr(x.rdd, y.rdd, getCorrNameOrDefault(method)) } + /** + * Java stub for mllib Statistics.chiSqTest() + */ + def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { + if (expected == null) { + Statistics.chiSqTest(observed) + } else { + Statistics.chiSqTest(observed, expected) + } + } + + /** + * Java stub for mllib Statistics.chiSqTest(observed: Matrix) + */ + def chiSqTest(observed: Matrix): ChiSqTestResult = { + Statistics.chiSqTest(observed) + } + + /** + * Java stub for mllib Statistics.chiSqTest(RDD[LabelPoint]) + */ + def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = { + Statistics.chiSqTest(data.rdd) + } + // used by the corr methods to retrieve the name of the correlation method passed in via pyspark private def getCorrNameOrDefault(method: String) = { if (method == null) CorrelationNames.defaultCorrName else method @@ -589,7 +648,7 @@ private[spark] object SerDe extends Serializable { if (objects.length == 0 || objects.length > 3) { out.write(Opcodes.MARK) } - objects.foreach(pickler.save(_)) + objects.foreach(pickler.save) val code = objects.length match { case 1 => Opcodes.TUPLE1 case 2 => Opcodes.TUPLE2 @@ -719,7 +778,7 @@ private[spark] object SerDe extends Serializable { } /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */ - def fromTuple2RDD(rdd: RDD[Tuple2[Any, Any]]): RDD[Array[Any]] = { + def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { rdd.map(x => Array(x._1, x._2)) } @@ -730,7 +789,7 @@ private[spark] object SerDe extends Serializable { def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { jRDD.rdd.mapPartitions { iter => initialize() // let it called in executor - new PythonRDD.AutoBatchedPickler(iter) + new SerDeUtil.AutoBatchedPickler(iter) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala new file mode 100644 index 0000000000000..6189dce9b27da --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.reflect.ClassTag + +import org.apache.spark.Logging +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom + +/** + * :: DeveloperApi :: + * StreamingKMeansModel extends MLlib's KMeansModel for streaming + * algorithms, so it can keep track of a continuously updated weight + * associated with each cluster, and also update the model by + * doing a single iteration of the standard k-means algorithm. + * + * The update algorithm uses the "mini-batch" KMeans rule, + * generalized to incorporate forgetfullness (i.e. decay). + * The update rule (for each cluster) is: + * + * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] + * n_t+t = n_t * a + m_t + * + * Where c_t is the previously estimated centroid for that cluster, + * n_t is the number of points assigned to it thus far, x_t is the centroid + * estimated on the current batch, and m_t is the number of points assigned + * to that centroid in the current batch. + * + * The decay factor 'a' scales the contribution of the clusters as estimated thus far, + * by applying a as a discount weighting on the current point when evaluating + * new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids + * are determined entirely by recent data. Lower values correspond to + * more forgetting. + * + * Decay can optionally be specified by a half life and associated + * time unit. The time unit can either be a batch of data or a single + * data point. Considering data arrived at time t, the half life h is defined + * such that at time t + h the discount applied to the data from t is 0.5. + * The definition remains the same whether the time unit is given + * as batches or points. + * + */ +@DeveloperApi +class StreamingKMeansModel( + override val clusterCenters: Array[Vector], + val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { + + /** Perform a k-means update on a batch of data. */ + def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = { + + // find nearest cluster to each point + val closest = data.map(point => (this.predict(point), (point, 1L))) + + // get sums and counts for updating each cluster + val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => { + BLAS.axpy(1.0, p2._1, p1._1) + (p1._1, p1._2 + p2._2) + } + val dim = clusterCenters(0).size + val pointStats: Array[(Int, (Vector, Long))] = closest + .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs) + .collect() + + val discount = timeUnit match { + case StreamingKMeans.BATCHES => decayFactor + case StreamingKMeans.POINTS => + val numNewPoints = pointStats.view.map { case (_, (_, n)) => + n + }.sum + math.pow(decayFactor, numNewPoints) + } + + // apply discount to weights + BLAS.scal(discount, Vectors.dense(clusterWeights)) + + // implement update rule + pointStats.foreach { case (label, (sum, count)) => + val centroid = clusterCenters(label) + + val updatedWeight = clusterWeights(label) + count + val lambda = count / math.max(updatedWeight, 1e-16) + + clusterWeights(label) = updatedWeight + BLAS.scal(1.0 - lambda, centroid) + BLAS.axpy(lambda / count, sum, centroid) + + // display the updated cluster centers + val display = clusterCenters(label).size match { + case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...") + case _ => centroid.toArray.mkString("[", ",", "]") + } + + logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display") + } + + // Check whether the smallest cluster is dying. If so, split the largest cluster. + val weightsWithIndex = clusterWeights.view.zipWithIndex + val (maxWeight, largest) = weightsWithIndex.maxBy(_._1) + val (minWeight, smallest) = weightsWithIndex.minBy(_._1) + if (minWeight < 1e-8 * maxWeight) { + logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.") + val weight = (maxWeight + minWeight) / 2.0 + clusterWeights(largest) = weight + clusterWeights(smallest) = weight + val largestClusterCenter = clusterCenters(largest) + val smallestClusterCenter = clusterCenters(smallest) + var j = 0 + while (j < dim) { + val x = largestClusterCenter(j) + val p = 1e-14 * math.max(math.abs(x), 1.0) + largestClusterCenter.toBreeze(j) = x + p + smallestClusterCenter.toBreeze(j) = x - p + j += 1 + } + } + + this + } +} + +/** + * :: DeveloperApi :: + * StreamingKMeans provides methods for configuring a + * streaming k-means analysis, training the model on streaming, + * and using the model to make predictions on streaming data. + * See KMeansModel for details on algorithm and update rules. + * + * Use a builder pattern to construct a streaming k-means analysis + * in an application, like: + * + * val model = new StreamingKMeans() + * .setDecayFactor(0.5) + * .setK(3) + * .setRandomCenters(5, 100.0) + * .trainOn(DStream) + */ +@DeveloperApi +class StreamingKMeans( + var k: Int, + var decayFactor: Double, + var timeUnit: String) extends Logging { + + def this() = this(2, 1.0, StreamingKMeans.BATCHES) + + protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null) + + /** Set the number of clusters. */ + def setK(k: Int): this.type = { + this.k = k + this + } + + /** Set the decay factor directly (for forgetful algorithms). */ + def setDecayFactor(a: Double): this.type = { + this.decayFactor = decayFactor + this + } + + /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */ + def setHalfLife(halfLife: Double, timeUnit: String): this.type = { + if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) { + throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit) + } + this.decayFactor = math.exp(math.log(0.5) / halfLife) + logInfo("Setting decay factor to: %g ".format (this.decayFactor)) + this.timeUnit = timeUnit + this + } + + /** Specify initial centers directly. */ + def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { + model = new StreamingKMeansModel(centers, weights) + this + } + + /** + * Initialize random centers, requiring only the number of dimensions. + * + * @param dim Number of dimensions + * @param weight Weight for each center + * @param seed Random seed + */ + def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { + val random = new XORShiftRandom(seed) + val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian()))) + val weights = Array.fill(k)(weight) + model = new StreamingKMeansModel(centers, weights) + this + } + + /** Return the latest model. */ + def latestModel(): StreamingKMeansModel = { + model + } + + /** + * Update the clustering model by training on batches of data from a DStream. + * This operation registers a DStream for training the model, + * checks whether the cluster centers have been initialized, + * and updates the model using each batch of data from the stream. + * + * @param data DStream containing vector data + */ + def trainOn(data: DStream[Vector]) { + assertInitialized() + data.foreachRDD { (rdd, time) => + model = model.update(rdd, decayFactor, timeUnit) + } + } + + /** + * Use the clustering model to make predictions on batches of data from a DStream. + * + * @param data DStream containing vector data + * @return DStream containing predictions + */ + def predictOn(data: DStream[Vector]): DStream[Int] = { + assertInitialized() + data.map(model.predict) + } + + /** + * Use the model to make predictions on the values of a DStream and carry over its keys. + * + * @param data DStream containing (key, feature vector) pairs + * @tparam K key type + * @return DStream containing the input keys and the predictions as values + */ + def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = { + assertInitialized() + data.mapValues(model.predict) + } + + /** Check whether cluster centers have been initialized. */ + private[this] def assertInitialized(): Unit = { + if (model.clusterCenters == null) { + throw new IllegalStateException( + "Initial cluster centers must be set before starting predictions") + } + } +} + +private[clustering] object StreamingKMeans { + final val BATCHES = "batches" + final val POINTS = "points" +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index 7858ec602483f..078fbfbe4f0e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -43,7 +43,7 @@ private[evaluation] object AreaUnderCurve { */ def of(curve: RDD[(Double, Double)]): Double = { curve.sliding(2).aggregate(0.0)( - seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points), + seqOp = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points), combOp = _ + _ ) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala new file mode 100644 index 0000000000000..ea10bde5fa252 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext._ + +/** + * Evaluator for multilabel classification. + * @param predictionAndLabels an RDD of (predictions, labels) pairs, + * both are non-null Arrays, each with unique elements. + */ +class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { + + private lazy val numDocs: Long = predictionAndLabels.count() + + private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) => + labels}.distinct().count() + + /** + * Returns subset accuracy + * (for equal sets of labels) + */ + lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) => + predictions.deep == labels.deep + }.count().toDouble / numDocs + + /** + * Returns accuracy + */ + lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) => + labels.intersect(predictions).size.toDouble / + (labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs + + + /** + * Returns Hamming-loss + */ + lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) => + labels.size + predictions.size - 2 * labels.intersect(predictions).size + }.sum / (numDocs * numLabels) + + /** + * Returns document-based precision averaged by the number of documents + */ + lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) => + if (predictions.size > 0) { + predictions.intersect(labels).size.toDouble / predictions.size + } else { + 0 + } + }.sum / numDocs + + /** + * Returns document-based recall averaged by the number of documents + */ + lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) => + labels.intersect(predictions).size.toDouble / labels.size + }.sum / numDocs + + /** + * Returns document-based f1-measure averaged by the number of documents + */ + lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) => + 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size) + }.sum / numDocs + + private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) => + predictions.intersect(labels) + }.countByValue() + + private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) => + predictions.diff(labels) + }.countByValue() + + private lazy val fnPerClass = predictionAndLabels.flatMap { case(predictions, labels) => + labels.diff(predictions) + }.countByValue() + + /** + * Returns precision for a given label (category) + * @param label the label. + */ + def precision(label: Double) = { + val tp = tpPerClass(label) + val fp = fpPerClass.getOrElse(label, 0L) + if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) + } + + /** + * Returns recall for a given label (category) + * @param label the label. + */ + def recall(label: Double) = { + val tp = tpPerClass(label) + val fn = fnPerClass.getOrElse(label, 0L) + if (tp + fn == 0) 0 else tp.toDouble / (tp + fn) + } + + /** + * Returns f1-measure for a given label (category) + * @param label the label. + */ + def f1Measure(label: Double) = { + val p = precision(label) + val r = recall(label) + if((p + r) == 0) 0 else 2 * p * r / (p + r) + } + + private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp } + private lazy val sumFpClass = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp } + private lazy val sumFnClass = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn } + + /** + * Returns micro-averaged label-based precision + * (equals to micro-averaged document-based precision) + */ + lazy val microPrecision = { + val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp} + sumTp.toDouble / (sumTp + sumFp) + } + + /** + * Returns micro-averaged label-based recall + * (equals to micro-averaged document-based recall) + */ + lazy val microRecall = { + val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn} + sumTp.toDouble / (sumTp + sumFn) + } + + /** + * Returns micro-averaged label-based f1-measure + * (equals to micro-averaged document-based f1-measure) + */ + lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass) + + /** + * Returns the sequence of labels in ascending order + */ + lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala new file mode 100644 index 0000000000000..693117d820580 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD +import org.apache.spark.Logging +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} + +/** + * :: Experimental :: + * Evaluator for regression. + * + * @param predictionAndObservations an RDD of (prediction, observation) pairs. + */ +@Experimental +class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging { + + /** + * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. + */ + private lazy val summary: MultivariateStatisticalSummary = { + val summary: MultivariateStatisticalSummary = predictionAndObservations.map { + case (prediction, observation) => Vectors.dense(observation, observation - prediction) + }.aggregate(new MultivariateOnlineSummarizer())( + (summary, v) => summary.add(v), + (sum1, sum2) => sum1.merge(sum2) + ) + summary + } + + /** + * Returns the explained variance regression score. + * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + */ + def explainedVariance: Double = { + 1 - summary.variance(1) / summary.variance(0) + } + + /** + * Returns the mean absolute error, which is a risk function corresponding to the + * expected value of the absolute error loss or l1-norm loss. + */ + def meanAbsoluteError: Double = { + summary.normL1(1) / summary.count + } + + /** + * Returns the mean squared error, which is a risk function corresponding to the + * expected value of the squared error loss or quadratic loss. + */ + def meanSquaredError: Double = { + val rmse = summary.normL2(1) / math.sqrt(summary.count) + rmse * rmse + } + + /** + * Returns the root mean squared error, which is defined as the square root of + * the mean squared error. + */ + def rootMeanSquaredError: Double = { + summary.normL2(1) / math.sqrt(summary.count) + } + + /** + * Returns R^2^, the coefficient of determination. + * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + */ + def r2: Double = { + 1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 6af225b7f49f7..ac217edc619ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -17,22 +17,26 @@ package org.apache.spark.mllib.linalg -import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import java.util +import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import scala.annotation.varargs import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} -import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException +import org.apache.spark.mllib.util.NumericParser +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row} +import org.apache.spark.sql.catalyst.types._ /** * Represents a numeric vector, whose index type is Int and value type is Double. * * Note: Users should not implement this interface. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) sealed trait Vector extends Serializable { /** @@ -74,6 +78,65 @@ sealed trait Vector extends Serializable { } } +/** + * User-defined type for [[Vector]] which allows easy interaction with SQL + * via [[org.apache.spark.sql.SchemaRDD]]. + */ +private[spark] class VectorUDT extends UserDefinedType[Vector] { + + override def sqlType: StructType = { + // type: 0 = sparse, 1 = dense + // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse + // vectors. The "values" field is nullable because we might want to add binary vectors later, + // which uses "size" and "indices", but not "values". + StructType(Seq( + StructField("type", ByteType, nullable = false), + StructField("size", IntegerType, nullable = true), + StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) + } + + override def serialize(obj: Any): Row = { + val row = new GenericMutableRow(4) + obj match { + case sv: SparseVector => + row.setByte(0, 0) + row.setInt(1, sv.size) + row.update(2, sv.indices.toSeq) + row.update(3, sv.values.toSeq) + case dv: DenseVector => + row.setByte(0, 1) + row.setNullAt(1) + row.setNullAt(2) + row.update(3, dv.values.toSeq) + } + row + } + + override def deserialize(datum: Any): Vector = { + datum match { + case row: Row => + require(row.length == 4, + s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") + val tpe = row.getByte(0) + tpe match { + case 0 => + val size = row.getInt(1) + val indices = row.getAs[Iterable[Int]](2).toArray + val values = row.getAs[Iterable[Double]](3).toArray + new SparseVector(size, indices, values) + case 1 => + val values = row.getAs[Iterable[Double]](3).toArray + new DenseVector(values) + } + } + } + + override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT" + + override def userClass: Class[Vector] = classOf[Vector] +} + /** * Factory methods for [[org.apache.spark.mllib.linalg.Vector]]. * We don't use the name `Vector` because Scala imports @@ -191,6 +254,7 @@ object Vectors { /** * A dense vector represented by a value array. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class DenseVector(val values: Array[Double]) extends Vector { override def size: Int = values.length @@ -215,6 +279,7 @@ class DenseVector(val values: Array[Double]) extends Vector { * @param indices index array, assume to be strictly increasing. * @param values value array, must have the same length as the index array. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class SparseVector( override val size: Int, val indices: Array[Int], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index b5e403bc8c14d..57c0768084e41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.rdd import scala.language.implicitConversions import scala.reflect.ClassTag +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD @@ -28,8 +29,8 @@ import org.apache.spark.util.Utils /** * Machine learning specific RDD functions. */ -private[mllib] -class RDDFunctions[T: ClassTag](self: RDD[T]) { +@DeveloperApi +class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { /** * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding @@ -39,10 +40,10 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { * trigger a Spark job if the parent RDD has more than one partitions and the window size is * greater than 1. */ - def sliding(windowSize: Int): RDD[Seq[T]] = { + def sliding(windowSize: Int): RDD[Array[T]] = { require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") if (windowSize == 1) { - self.map(Seq(_)) + self.map(Array(_)) } else { new SlidingRDD[T](self, windowSize) } @@ -112,7 +113,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { } } -private[mllib] +@DeveloperApi object RDDFunctions { /** Implicit conversion from an RDD to RDDFunctions. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index dd80782c0f001..35e81fcb3de0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -45,15 +45,16 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T] */ private[mllib] class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) - extends RDD[Seq[T]](parent) { + extends RDD[Array[T]](parent) { require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") - override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = { + override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = { val part = split.asInstanceOf[SlidingRDDPartition[T]] (firstParent[T].iterator(part.prev, context) ++ part.tail) .sliding(windowSize) .withPartial(false) + .map(_.toArray) } override def getPreferredLocations(split: Partition): Seq[String] = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6737a2f4176c2..78acc17f901c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -62,7 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Note: random seed will not be used since numTrees = 1. val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) val rfModel = rf.train(input) - rfModel.trees(0) + rfModel.weakHypotheses(0) } } @@ -437,6 +437,11 @@ object DecisionTree extends Serializable with Logging { * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). * Updated with new non-leaf nodes which are created. + * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where + * each value in the array is the data point's node Id + * for a corresponding tree. This is used to prevent the need + * to pass the entire tree to the executors during + * the node stat aggregation phase. */ private[tree] def findBestSplits( input: RDD[BaggedPoint[TreePoint]], @@ -447,7 +452,8 @@ object DecisionTree extends Serializable with Logging { splits: Array[Array[Split]], bins: Array[Array[Bin]], nodeQueue: mutable.Queue[(Int, Node)], - timer: TimeTracker = new TimeTracker): Unit = { + timer: TimeTracker = new TimeTracker, + nodeIdCache: Option[NodeIdCache] = None): Unit = { /* * The high-level descriptions of the best split optimizations are noted here. @@ -479,6 +485,37 @@ object DecisionTree extends Serializable with Logging { logDebug("isMulticlass = " + metadata.isMulticlass) logDebug("isMulticlassWithCategoricalFeatures = " + metadata.isMulticlassWithCategoricalFeatures) + logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString) + + /** + * Performs a sequential aggregation over a partition for a particular tree and node. + * + * For each feature, the aggregate sufficient statistics are updated for the relevant + * bins. + * + * @param treeIndex Index of the tree that we want to perform aggregation for. + * @param nodeInfo The node info for the tree node. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics + * for each (node, feature, bin). + * @param baggedPoint Data point being aggregated. + */ + def nodeBinSeqOp( + treeIndex: Int, + nodeInfo: RandomForest.NodeIndexInfo, + agg: Array[DTStatsAggregator], + baggedPoint: BaggedPoint[TreePoint]): Unit = { + if (nodeInfo != null) { + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val instanceWeight = baggedPoint.subsampleWeights(treeIndex) + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) + } else { + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, + instanceWeight, featuresForNode) + } + } + } /** * Performs a sequential aggregation over a partition. @@ -497,20 +534,25 @@ object DecisionTree extends Serializable with Logging { treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, bins, metadata.unorderedFeatures) - val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null) - // If the example does not reach a node in this group, then nodeIndex = null. - if (nodeInfo != null) { - val aggNodeIndex = nodeInfo.nodeIndexInGroup - val featuresForNode = nodeInfo.featureSubset - val instanceWeight = baggedPoint.subsampleWeights(treeIndex) - if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) - } else { - mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, - instanceWeight, featuresForNode) - } - } + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) + } + + agg + } + + /** + * Do the same thing as binSeqOp, but with nodeIdCache. + */ + def binSeqOpWithNodeIdCache( + agg: Array[DTStatsAggregator], + dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val baggedPoint = dataPoint._1 + val nodeIdCache = dataPoint._2 + val nodeIndex = nodeIdCache(treeIndex) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) } + agg } @@ -553,7 +595,26 @@ object DecisionTree extends Serializable with Logging { // Finally, only best Splits for nodes are collected to driver to construct decision tree. val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) - val nodeToBestSplits = + + val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { + input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + new DTStatsAggregator(metadata, featuresForNode) + } + + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _)) + + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using `reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + } + } else { input.mapPartitions { points => // Construct a nodeStatsAggregators array to hold node aggregate stats, // each node will have a nodeStatsAggregator @@ -570,7 +631,10 @@ object DecisionTree extends Serializable with Logging { // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, // which can be combined with other partition using `reduceByKey` nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator - }.reduceByKey((a, b) => a.merge(b)) + } + } + + val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)) .map { case (nodeIndex, aggStats) => val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => Some(nodeToFeatures(nodeIndex)) @@ -584,6 +648,13 @@ object DecisionTree extends Serializable with Logging { timer.stop("chooseSplits") + val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { + Array.fill[mutable.Map[Int, NodeIndexUpdater]]( + metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]()) + } else { + null + } + // Iterate over all nodes in this group. nodesForGroup.foreach { case (treeIndex, nodesForTree) => nodesForTree.foreach { node => @@ -613,6 +684,13 @@ object DecisionTree extends Serializable with Logging { node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex), stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + if (nodeIdCache.nonEmpty) { + val nodeIndexUpdater = NodeIndexUpdater( + split = split, + nodeIndex = nodeIndex) + nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater) + } + // enqueue left child and right child if they are not leaves if (!leftChildIsLeaf) { nodeQueue.enqueue((treeIndex, node.leftNode.get)) @@ -629,6 +707,10 @@ object DecisionTree extends Serializable with Logging { } } + if (nodeIdCache.nonEmpty) { + // Update the cache if needed. + nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins) + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala new file mode 100644 index 0000000000000..f729344a682e2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum +import org.apache.spark.mllib.tree.impl.TimeTracker +import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * :: Experimental :: + * A class that implements Stochastic Gradient Boosting + * for regression and binary classification problems. + * + * The implementation is based upon: + * J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * + * Notes: + * - This currently can be run with several loss functions. However, only SquaredError is + * fully supported. Specifically, the loss function should be used to compute the gradient + * (to re-label training instances on each iteration) and to weight weak hypotheses. + * Currently, gradients are computed correctly for the available loss functions, + * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError. + * Running with those losses will likely behave reasonably, but lacks the same guarantees. + * + * @param boostingStrategy Parameters for the gradient boosting algorithm + */ +@Experimental +class GradientBoosting ( + private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { + + boostingStrategy.weakLearnerParams.algo = Regression + boostingStrategy.weakLearnerParams.impurity = impurity.Variance + + // Ensure values for weak learner are the same as what is provided to the boosting algorithm. + boostingStrategy.weakLearnerParams.numClassesForClassification = + boostingStrategy.numClassesForClassification + + boostingStrategy.assertValid() + + /** + * Method to train a gradient boosting model + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return WeightedEnsembleModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = { + val algo = boostingStrategy.algo + algo match { + case Regression => GradientBoosting.boost(input, boostingStrategy) + case Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. + val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + GradientBoosting.boost(remappedInput, boostingStrategy) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") + } + } + +} + + +object GradientBoosting extends Logging { + + /** + * Method to train a gradient boosting model. + * + * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] + * is recommended to clearly specify regression. + * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] + * is recommended to clearly specify regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. + * @param boostingStrategy Configuration options for the boosting algorithm. + * @return WeightedEnsembleModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + new GradientBoosting(boostingStrategy).train(input) + } + + /** + * Method to train a gradient boosting classification model. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. + * @param boostingStrategy Configuration options for the boosting algorithm. + * @return WeightedEnsembleModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + val algo = boostingStrategy.algo + require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.") + new GradientBoosting(boostingStrategy).train(input) + } + + /** + * Method to train a gradient boosting regression model. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. + * @param boostingStrategy Configuration options for the boosting algorithm. + * @return WeightedEnsembleModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + val algo = boostingStrategy.algo + require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.") + new GradientBoosting(boostingStrategy).train(input) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]] + */ + def train( + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + train(input.rdd, boostingStrategy) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] + */ + def trainClassifier( + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + trainClassifier(input.rdd, boostingStrategy) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] + */ + def trainRegressor( + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + trainRegressor(input.rdd, boostingStrategy) + } + + /** + * Internal method for performing regression using trees as base learners. + * @param input training dataset + * @param boostingStrategy boosting parameters + * @return + */ + private def boost( + input: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + + val timer = new TimeTracker() + timer.start("total") + timer.start("init") + + // Initialize gradient boosting parameters + val numIterations = boostingStrategy.numIterations + val baseLearners = new Array[DecisionTreeModel](numIterations) + val baseLearnerWeights = new Array[Double](numIterations) + val loss = boostingStrategy.loss + val learningRate = boostingStrategy.learningRate + val strategy = boostingStrategy.weakLearnerParams + + // Cache input + if (input.getStorageLevel == StorageLevel.NONE) { + input.persist(StorageLevel.MEMORY_AND_DISK) + } + + timer.stop("init") + + logDebug("##########") + logDebug("Building tree 0") + logDebug("##########") + var data = input + + // Initialize tree + timer.start("building tree 0") + val firstTreeModel = new DecisionTree(strategy).train(data) + baseLearners(0) = firstTreeModel + baseLearnerWeights(0) = 1.0 + val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression, + Sum) + logDebug("error of gbt = " + loss.computeError(startingModel, input)) + // Note: A model of type regression is used since we require raw prediction + timer.stop("building tree 0") + + // psuedo-residual for second iteration + data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), + point.features)) + + var m = 1 + while (m < numIterations) { + timer.start(s"building tree $m") + logDebug("###################################################") + logDebug("Gradient boosting tree iteration " + m) + logDebug("###################################################") + val model = new DecisionTree(strategy).train(data) + timer.stop(s"building tree $m") + // Create partial model + baseLearners(m) = model + // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. + // Technically, the weight should be optimized for the particular loss. + // However, the behavior should be reasonable, though not optimal. + baseLearnerWeights(m) = learningRate + // Note: A model of type regression is used since we require raw prediction + val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1), + baseLearnerWeights.slice(0, m + 1), Regression, Sum) + logDebug("error of gbt = " + loss.computeError(partialModel, input)) + // Update data with pseudo-residuals + data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), + point.features)) + m += 1 + } + + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + + new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum) + + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index ebbd8e0257209..9683916d9b3f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -26,8 +26,9 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker} +import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache } import org.apache.spark.mllib.tree.impurity.Impurities import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD @@ -59,7 +60,7 @@ import org.apache.spark.util.Utils * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. - * @param seed Random seed for bootstrapping and choosing feature subsets. + * @param seed Random seed for bootstrapping and choosing feature subsets. */ @Experimental private class RandomForest ( @@ -78,9 +79,9 @@ private class RandomForest ( /** * Method to train a decision tree model over an RDD * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return RandomForestModel that can be used for prediction + * @return WeightedEnsembleModel that can be used for prediction */ - def train(input: RDD[LabeledPoint]): RandomForestModel = { + def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = { val timer = new TimeTracker() @@ -111,11 +112,20 @@ private class RandomForest ( // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) - val baggedInput = if (numTrees > 1) { - BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed) - } else { - BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) - }.persist(StorageLevel.MEMORY_AND_DISK) + + val (subsample, withReplacement) = { + // TODO: Have a stricter check for RF in the strategy + val isRandomForest = numTrees > 1 + if (isRandomForest) { + (1.0, true) + } else { + (strategy.subsamplingRate, false) + } + } + + val baggedInput + = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed) + .persist(StorageLevel.MEMORY_AND_DISK) // depth of the decision tree val maxDepth = strategy.maxDepth @@ -150,6 +160,19 @@ private class RandomForest ( * in lower levels). */ + // Create an RDD of node Id cache. + // At first, all the rows belong to the root nodes (node Id == 1). + val nodeIdCache = if (strategy.useNodeIdCache) { + Some(NodeIdCache.init( + data = baggedInput, + numTrees = numTrees, + checkpointDir = strategy.checkpointDir, + checkpointInterval = strategy.checkpointInterval, + initVal = 1)) + } else { + None + } + // FIFO queue of nodes to train: (treeIndex, node) val nodeQueue = new mutable.Queue[(Int, Node)]() @@ -172,7 +195,7 @@ private class RandomForest ( // Choose node splits, and enqueue new nodes as needed. timer.start("findBestSplits") DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, - treeToNodeToIndexInfo, splits, bins, nodeQueue, timer) + treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache) timer.stop("findBestSplits") } @@ -183,8 +206,14 @@ private class RandomForest ( logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + // Delete any remaining checkpoints used for node Id cache. + if (nodeIdCache.nonEmpty) { + nodeIdCache.get.deleteAllCheckpoints() + } + val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) - RandomForestModel.build(trees) + val treeWeights = Array.fill[Double](numTrees)(1.0) + new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average) } } @@ -205,14 +234,14 @@ object RandomForest extends Serializable with Logging { * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return RandomForestModel that can be used for prediction + * @return WeightedEnsembleModel that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String, - seed: Int): RandomForestModel = { + seed: Int): WeightedEnsembleModel = { require(strategy.algo == Classification, s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) @@ -243,7 +272,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return RandomForestModel that can be used for prediction + * @return WeightedEnsembleModel that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], @@ -254,7 +283,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int = Utils.random.nextInt()): RandomForestModel = { + seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = { val impurityType = Impurities.fromString(impurity) val strategy = new Strategy(Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo) @@ -273,7 +302,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int): RandomForestModel = { + seed: Int): WeightedEnsembleModel = { trainClassifier(input.rdd, numClassesForClassification, categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) @@ -293,14 +322,14 @@ object RandomForest extends Serializable with Logging { * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return RandomForestModel that can be used for prediction + * @return WeightedEnsembleModel that can be used for prediction */ def trainRegressor( input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String, - seed: Int): RandomForestModel = { + seed: Int): WeightedEnsembleModel = { require(strategy.algo == Regression, s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) @@ -330,7 +359,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return RandomForestModel that can be used for prediction + * @return WeightedEnsembleModel that can be used for prediction */ def trainRegressor( input: RDD[LabeledPoint], @@ -340,7 +369,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int = Utils.random.nextInt()): RandomForestModel = { + seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = { val impurityType = Impurities.fromString(impurity) val strategy = new Strategy(Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo) @@ -358,7 +387,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int): RandomForestModel = { + seed: Int): WeightedEnsembleModel = { trainRegressor(input.rdd, categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala new file mode 100644 index 0000000000000..abbda040bd528 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import scala.beans.BeanProperty + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} + +/** + * :: Experimental :: + * Stores all the configuration options for the boosting algorithms + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @param numIterations Number of iterations of boosting. In other words, the number of + * weak hypotheses used in the final model. + * @param loss Loss function used for minimization during gradient boosting. + * @param learningRate Learning rate for shrinking the contribution of each estimator. The + * learning rate should be between in the interval (0, 1] + * @param numClassesForClassification Number of classes for classification. + * (Ignored for regression.) + * This setting overrides any setting in [[weakLearnerParams]]. + * Default value is 2 (binary classification). + * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are + * supported. + */ +@Experimental +case class BoostingStrategy( + // Required boosting parameters + @BeanProperty var algo: Algo, + @BeanProperty var numIterations: Int, + @BeanProperty var loss: Loss, + // Optional boosting parameters + @BeanProperty var learningRate: Double = 0.1, + @BeanProperty var numClassesForClassification: Int = 2, + @BeanProperty var weakLearnerParams: Strategy) extends Serializable { + + // Ensure values for weak learner are the same as what is provided to the boosting algorithm. + weakLearnerParams.numClassesForClassification = numClassesForClassification + + /** + * Sets Algorithm using a String. + */ + def setAlgo(algo: String): Unit = algo match { + case "Classification" => setAlgo(Classification) + case "Regression" => setAlgo(Regression) + } + + /** + * Check validity of parameters. + * Throws exception if invalid. + */ + private[tree] def assertValid(): Unit = { + algo match { + case Classification => + require(numClassesForClassification == 2) + case Regression => + // nothing + case _ => + throw new IllegalArgumentException( + s"BoostingStrategy given invalid algo parameter: $algo." + + s" Valid settings are: Classification, Regression.") + } + require(learningRate > 0 && learningRate <= 1, + "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.") + } +} + +@Experimental +object BoostingStrategy { + + /** + * Returns default configuration for the boosting algorithm + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @return Configuration for boosting algorithm + */ + def defaultParams(algo: String): BoostingStrategy = { + val treeStrategy = Strategy.defaultStrategy("Regression") + treeStrategy.maxDepth = 3 + algo match { + case "Classification" => + new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy) + case "Regression" => + new BoostingStrategy(Algo.withName(algo), 100, SquaredError, + weakLearnerParams = treeStrategy) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by the boosting.") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala new file mode 100644 index 0000000000000..82889dc00cdad --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: Experimental :: + * Enum to select ensemble combining strategy for base learners + */ +@DeveloperApi +object EnsembleCombiningStrategy extends Enumeration { + type EnsembleCombiningStrategy = Value + val Sum, Average = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index caaccbfb8ad16..b5b1f82177edc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.tree.configuration +import scala.beans.BeanProperty import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental @@ -43,7 +44,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * for choosing how to split on features at each node. * More bins give higher granularity. * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: - * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] + * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] * @param categoricalFeaturesInfo A map storing information about the categorical variables and the * number of discrete values they take. For example, an entry (n -> * k) implies the feature n is categorical with k categories 0, @@ -58,31 +59,35 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * this split will not be considered as a valid split. * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 256 MB. + * @param subsamplingRate Fraction of the training data used for learning decision tree. + * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will + * maintain a separate RDD of node Id cache for each row. + * @param checkpointDir If the node Id cache is used, it will help to checkpoint + * the node Id cache periodically. This is the checkpoint directory + * to be used for the node Id cache. + * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. + * E.g. 10 means that the cache will get checkpointed every 10 updates. */ @Experimental class Strategy ( - val algo: Algo, - val impurity: Impurity, - val maxDepth: Int, - val numClassesForClassification: Int = 2, - val maxBins: Int = 32, - val quantileCalculationStrategy: QuantileStrategy = Sort, - val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val minInstancesPerNode: Int = 1, - val minInfoGain: Double = 0.0, - val maxMemoryInMB: Int = 256) extends Serializable { + @BeanProperty var algo: Algo, + @BeanProperty var impurity: Impurity, + @BeanProperty var maxDepth: Int, + @BeanProperty var numClassesForClassification: Int = 2, + @BeanProperty var maxBins: Int = 32, + @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort, + @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + @BeanProperty var minInstancesPerNode: Int = 1, + @BeanProperty var minInfoGain: Double = 0.0, + @BeanProperty var maxMemoryInMB: Int = 256, + @BeanProperty var subsamplingRate: Double = 1, + @BeanProperty var useNodeIdCache: Boolean = false, + @BeanProperty var checkpointDir: Option[String] = None, + @BeanProperty var checkpointInterval: Int = 10) extends Serializable { - if (algo == Classification) { - require(numClassesForClassification >= 2) - } - require(minInstancesPerNode >= 1, - s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") - require(maxMemoryInMB <= 10240, - s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") - - val isMulticlassClassification = + def isMulticlassClassification = algo == Classification && numClassesForClassification > 2 - val isMulticlassWithCategoricalFeatures + def isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) /** @@ -99,6 +104,23 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + /** + * Sets Algorithm using a String. + */ + def setAlgo(algo: String): Unit = algo match { + case "Classification" => setAlgo(Classification) + case "Regression" => setAlgo(Regression) + } + + /** + * Sets categoricalFeaturesInfo using a Java Map. + */ + def setCategoricalFeaturesInfo( + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = { + setCategoricalFeaturesInfo( + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) + } + /** * Check validity of parameters. * Throws exception if invalid. @@ -130,6 +152,26 @@ class Strategy ( s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + s" feature $feature has $arity categories. The number of categories should be >= 2.") } + require(minInstancesPerNode >= 1, + s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") + require(maxMemoryInMB <= 10240, + s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") } +} + +@Experimental +object Strategy { + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo "Classification" or "Regression" + */ + def defaultStrategy(algo: String): Strategy = algo match { + case "Classification" => + new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, + numClassesForClassification = 2) + case "Regression" => + new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, + numClassesForClassification = 0) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala index e7a2127c5d2e7..089010c81ffb6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala @@ -21,13 +21,14 @@ import org.apache.commons.math3.distribution.PoissonDistribution import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom /** * Internal representation of a datapoint which belongs to several subsamples of the same dataset, * particularly for bagging (e.g., for random forests). * * This holds one instance, as well as an array of weights which represent the (weighted) - * number of times which this instance appears in each subsample. + * number of times which this instance appears in each subsamplingRate. * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively. * @@ -44,22 +45,65 @@ private[tree] object BaggedPoint { /** * Convert an input dataset into its BaggedPoint representation, - * choosing subsample counts for each instance. - * Each subsample has the same number of instances as the original dataset, - * and is created by subsampling with replacement. - * @param input Input dataset. - * @param numSubsamples Number of subsamples of this RDD to take. - * @param seed Random seed. - * @return BaggedPoint dataset representation + * choosing subsamplingRate counts for each instance. + * Each subsamplingRate has the same number of instances as the original dataset, + * and is created by subsampling without replacement. + * @param input Input dataset. + * @param subsamplingRate Fraction of the training data used for learning decision tree. + * @param numSubsamples Number of subsamples of this RDD to take. + * @param withReplacement Sampling with/without replacement. + * @param seed Random seed. + * @return BaggedPoint dataset representation. */ - def convertToBaggedRDD[Datum]( + def convertToBaggedRDD[Datum] ( input: RDD[Datum], + subsamplingRate: Double, numSubsamples: Int, + withReplacement: Boolean, seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = { + if (withReplacement) { + convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) + } else { + if (numSubsamples == 1 && subsamplingRate == 1.0) { + convertToBaggedRDDWithoutSampling(input) + } else { + convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) + } + } + } + + private def convertToBaggedRDDSamplingWithoutReplacement[Datum] ( + input: RDD[Datum], + subsamplingRate: Double, + numSubsamples: Int, + seed: Int): RDD[BaggedPoint[Datum]] = { + input.mapPartitionsWithIndex { (partitionIndex, instances) => + // Use random seed = seed + partitionIndex + 1 to make generation reproducible. + val rng = new XORShiftRandom + rng.setSeed(seed + partitionIndex + 1) + instances.map { instance => + val subsampleWeights = new Array[Double](numSubsamples) + var subsampleIndex = 0 + while (subsampleIndex < numSubsamples) { + val x = rng.nextDouble() + subsampleWeights(subsampleIndex) = { + if (x < subsamplingRate) 1.0 else 0.0 + } + subsampleIndex += 1 + } + new BaggedPoint(instance, subsampleWeights) + } + } + } + + private def convertToBaggedRDDSamplingWithReplacement[Datum] ( + input: RDD[Datum], + subsample: Double, + numSubsamples: Int, + seed: Int): RDD[BaggedPoint[Datum]] = { input.mapPartitionsWithIndex { (partitionIndex, instances) => - // TODO: Support different sampling rates, and sampling without replacement. // Use random seed = seed + partitionIndex + 1 to make generation reproducible. - val poisson = new PoissonDistribution(1.0) + val poisson = new PoissonDistribution(subsample) poisson.reseedRandomGenerator(seed + partitionIndex + 1) instances.map { instance => val subsampleWeights = new Array[Double](numSubsamples) @@ -73,7 +117,8 @@ private[tree] object BaggedPoint { } } - def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { + private def convertToBaggedRDDWithoutSampling[Datum] ( + input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { input.map(datum => new BaggedPoint(datum, Array(1.0))) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala new file mode 100644 index 0000000000000..83011b48b7d9b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.tree.model.{Bin, Node, Split} + +/** + * :: DeveloperApi :: + * This is used by the node id cache to find the child id that a data point would belong to. + * @param split Split information. + * @param nodeIndex The current node index of a data point that this will update. + */ +@DeveloperApi +private[tree] case class NodeIndexUpdater( + split: Split, + nodeIndex: Int) { + /** + * Determine a child node index based on the feature value and the split. + * @param binnedFeatures Binned feature values. + * @param bins Bin information to convert the bin indices to approximate feature values. + * @return Child node index to update to. + */ + def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = { + if (split.featureType == Continuous) { + val featureIndex = split.feature + val binIndex = binnedFeatures(featureIndex) + val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold + if (featureValueUpperBound <= split.threshold) { + Node.leftChildIndex(nodeIndex) + } else { + Node.rightChildIndex(nodeIndex) + } + } else { + if (split.categories.contains(binnedFeatures(split.feature).toDouble)) { + Node.leftChildIndex(nodeIndex) + } else { + Node.rightChildIndex(nodeIndex) + } + } + } +} + +/** + * :: DeveloperApi :: + * A given TreePoint would belong to a particular node per tree. + * Each row in the nodeIdsForInstances RDD is an array over trees of the node index + * in each tree. Initially, values should all be 1 for root node. + * The nodeIdsForInstances RDD needs to be updated at each iteration. + * @param nodeIdsForInstances The initial values in the cache + * (should be an Array of all 1's (meaning the root nodes)). + * @param checkpointDir The checkpoint directory where + * the checkpointed files will be stored. + * @param checkpointInterval The checkpointing interval + * (how often should the cache be checkpointed.). + */ +@DeveloperApi +private[tree] class NodeIdCache( + var nodeIdsForInstances: RDD[Array[Int]], + val checkpointDir: Option[String], + val checkpointInterval: Int) { + + // Keep a reference to a previous node Ids for instances. + // Because we will keep on re-persisting updated node Ids, + // we want to unpersist the previous RDD. + private var prevNodeIdsForInstances: RDD[Array[Int]] = null + + // To keep track of the past checkpointed RDDs. + private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() + private var rddUpdateCount = 0 + + // If a checkpoint directory is given, and there's no prior checkpoint directory, + // then set the checkpoint directory with the given one. + if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) { + nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get) + } + + /** + * Update the node index values in the cache. + * This updates the RDD and its lineage. + * TODO: Passing bin information to executors seems unnecessary and costly. + * @param data The RDD of training rows. + * @param nodeIdUpdaters A map of node index updaters. + * The key is the indices of nodes that we want to update. + * @param bins Bin information needed to find child node indices. + */ + def updateNodeIndices( + data: RDD[BaggedPoint[TreePoint]], + nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]], + bins: Array[Array[Bin]]): Unit = { + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } + + prevNodeIdsForInstances = nodeIdsForInstances + nodeIdsForInstances = data.zip(nodeIdsForInstances).map { + dataPoint => { + var treeId = 0 + while (treeId < nodeIdUpdaters.length) { + val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null) + if (nodeIdUpdater != null) { + val newNodeIndex = nodeIdUpdater.updateNodeIndex( + binnedFeatures = dataPoint._1.datum.binnedFeatures, + bins = bins) + dataPoint._2(treeId) = newNodeIndex + } + + treeId += 1 + } + + dataPoint._2 + } + } + + // Keep on persisting new ones. + nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK) + rddUpdateCount += 1 + + // Handle checkpointing if the directory is not None. + if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty && + (rddUpdateCount % checkpointInterval) == 0) { + // Let's see if we can delete previous checkpoints. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // We can delete the oldest checkpoint iff + // the next checkpoint actually exists in the file system. + if (checkpointQueue.get(1).get.getCheckpointFile != None) { + val old = checkpointQueue.dequeue() + + // Since the old checkpoint is not deleted by Spark, + // we'll manually delete it here. + val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) + fs.delete(new Path(old.getCheckpointFile.get), true) + } else { + canDelete = false + } + } + + nodeIdsForInstances.checkpoint() + checkpointQueue.enqueue(nodeIdsForInstances) + } + } + + /** + * Call this after training is finished to delete any remaining checkpoints. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.size > 0) { + val old = checkpointQueue.dequeue() + if (old.getCheckpointFile != None) { + val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) + fs.delete(new Path(old.getCheckpointFile.get), true) + } + } + } +} + +@DeveloperApi +private[tree] object NodeIdCache { + /** + * Initialize the node Id cache with initial node Id values. + * @param data The RDD of training rows. + * @param numTrees The number of trees that we want to create cache for. + * @param checkpointDir The checkpoint directory where the checkpointed files will be stored. + * @param checkpointInterval The checkpointing interval + * (how often should the cache be checkpointed.). + * @param initVal The initial values in the cache. + * @return A node Id cache containing an RDD of initial root node Indices. + */ + def init( + data: RDD[BaggedPoint[TreePoint]], + numTrees: Int, + checkpointDir: Option[String], + checkpointInterval: Int, + initVal: Int = 1): NodeIdCache = { + new NodeIdCache( + data.map(_ => Array.fill[Int](numTrees)(initVal)), + checkpointDir, + checkpointInterval) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala new file mode 100644 index 0000000000000..d111ffe30ed9e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Class for least absolute error loss calculation. + * The features x and the corresponding label y is predicted using the function F. + * For each instance: + * Loss: |y - F| + * Negative gradient: sign(y - F) + */ +@DeveloperApi +object AbsoluteError extends Loss { + + /** + * Method to calculate the gradients for the gradient boosting calculation for least + * absolute error calculation. + * @param model Model of the weak learner + * @param point Instance of the training dataset + * @return Loss gradient + */ + override def gradient( + model: WeightedEnsembleModel, + point: LabeledPoint): Double = { + if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0 + } + + /** + * Method to calculate error of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param model Model of the weak learner. + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return + */ + override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + val sumOfAbsolutes = data.map { y => + val err = model.predict(y.features) - y.label + math.abs(err) + }.sum() + sumOfAbsolutes / data.count() + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala new file mode 100644 index 0000000000000..6f3d4340f0d3b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Class for least squares error loss calculation. + * + * The features x and the corresponding label y is predicted using the function F. + * For each instance: + * Loss: log(1 + exp(-2yF)), y in {-1, 1} + * Negative gradient: 2y / ( 1 + exp(2yF)) + */ +@DeveloperApi +object LogLoss extends Loss { + + /** + * Method to calculate the loss gradients for the gradient boosting calculation for binary + * classification + * @param model Model of the weak learner + * @param point Instance of the training dataset + * @return Loss gradient + */ + override def gradient( + model: WeightedEnsembleModel, + point: LabeledPoint): Double = { + val prediction = model.predict(point.features) + 1.0 / (1.0 + math.exp(-prediction)) - point.label + } + + /** + * Method to calculate error of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param model Model of the weak learner. + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return + */ + override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count() + wrongPredictions / data.count + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala new file mode 100644 index 0000000000000..5580866c879e2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Trait for adding "pluggable" loss functions for the gradient boosting algorithm. + */ +@DeveloperApi +trait Loss extends Serializable { + + /** + * Method to calculate the gradients for the gradient boosting calculation. + * @param model Model of the weak learner. + * @param point Instance of the training dataset. + * @return Loss gradient. + */ + def gradient( + model: WeightedEnsembleModel, + point: LabeledPoint): Double + + /** + * Method to calculate error of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param model Model of the weak learner. + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return + */ + def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala new file mode 100644 index 0000000000000..42c9ead9884b4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +object Losses { + + def fromString(name: String): Loss = name match { + case "leastSquaresError" => SquaredError + case "leastAbsoluteError" => AbsoluteError + case "logLoss" => LogLoss + case _ => throw new IllegalArgumentException(s"Did not recognize Loss name: $name") + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala new file mode 100644 index 0000000000000..4349fefef2c74 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Class for least squares error loss calculation. + * + * The features x and the corresponding label y is predicted using the function F. + * For each instance: + * Loss: (y - F)**2/2 + * Negative gradient: y - F + */ +@DeveloperApi +object SquaredError extends Loss { + + /** + * Method to calculate the gradients for the gradient boosting calculation for least + * squares error calculation. + * @param model Model of the weak learner + * @param point Instance of the training dataset + * @return Loss gradient + */ + override def gradient( + model: WeightedEnsembleModel, + point: LabeledPoint): Double = { + model.predict(point.features) - point.label + } + + /** + * Method to calculate error of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param model Model of the weak learner. + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return + */ + override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + data.map { y => + val err = model.predict(y.features) - y.label + err * err + }.mean() + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala deleted file mode 100644 index 6a22e2abe59bd..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.model - -import scala.collection.mutable - -import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.rdd.RDD - -/** - * :: Experimental :: - * Random forest model for classification or regression. - * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make - * aggregate predictions. - * @param trees Trees which make up this forest. This cannot be empty. - * @param algo algorithm type -- classification or regression - */ -@Experimental -class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable { - - require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.") - - /** - * Predict values for a single data point. - * - * @param features array representing a single data point - * @return Double prediction from the trained model - */ - def predict(features: Vector): Double = { - algo match { - case Classification => - val predictionToCount = new mutable.HashMap[Int, Int]() - trees.foreach { tree => - val prediction = tree.predict(features).toInt - predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1 - } - predictionToCount.maxBy(_._2)._1 - case Regression => - trees.map(_.predict(features)).sum / trees.size - } - } - - /** - * Predict values for the given data set. - * - * @param features RDD representing data points to be predicted - * @return RDD[Double] where each entry contains the corresponding prediction - */ - def predict(features: RDD[Vector]): RDD[Double] = { - features.map(x => predict(x)) - } - - /** - * Get number of trees in forest. - */ - def numTrees: Int = trees.size - - /** - * Get total number of nodes, summed over all trees in the forest. - */ - def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum - - /** - * Print a summary of the model. - */ - override def toString: String = algo match { - case Classification => - s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes" - case Regression => - s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes" - case _ => throw new IllegalArgumentException( - s"RandomForestModel given unknown algo parameter: $algo.") - } - - /** - * Print the full model to a string. - */ - def toDebugString: String = { - val header = toString + "\n" - header + trees.zipWithIndex.map { case (tree, treeIndex) => - s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) - }.fold("")(_ + _) - } - -} - -private[tree] object RandomForestModel { - - def build(trees: Array[DecisionTreeModel]): RandomForestModel = { - require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.") - val algo: Algo = trees(0).algo - require(trees.forall(_.algo == algo), - "RandomForestModel cannot combine trees which have different output types" + - " (classification/regression).") - new RandomForestModel(trees, algo) - } - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala new file mode 100644 index 0000000000000..7b052d9163a13 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ +import org.apache.spark.rdd.RDD + +import scala.collection.mutable + +@Experimental +class WeightedEnsembleModel( + val weakHypotheses: Array[DecisionTreeModel], + val weakHypothesisWeights: Array[Double], + val algo: Algo, + val combiningStrategy: EnsembleCombiningStrategy) extends Serializable { + + require(numWeakHypotheses > 0, s"WeightedEnsembleModel cannot be created without weakHypotheses" + + s". Number of weakHypotheses = $weakHypotheses") + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + private def predictRaw(features: Vector): Double = { + val treePredictions = weakHypotheses.map(learner => learner.predict(features)) + if (numWeakHypotheses == 1){ + treePredictions(0) + } else { + var prediction = treePredictions(0) + var index = 1 + while (index < numWeakHypotheses) { + prediction += weakHypothesisWeights(index) * treePredictions(index) + index += 1 + } + prediction + } + } + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + private def predictBySumming(features: Vector): Double = { + algo match { + case Regression => predictRaw(features) + case Classification => { + // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info. + if (predictRaw(features) > 0 ) 1.0 else 0.0 + } + case _ => throw new IllegalArgumentException( + s"WeightedEnsembleModel given unknown algo parameter: $algo.") + } + } + + /** + * Predict values for a single data point. + * + * @param features array representing a single data point + * @return Double prediction from the trained model + */ + private def predictByAveraging(features: Vector): Double = { + algo match { + case Classification => + val predictionToCount = new mutable.HashMap[Int, Int]() + weakHypotheses.foreach { learner => + val prediction = learner.predict(features).toInt + predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1 + } + predictionToCount.maxBy(_._2)._1 + case Regression => + weakHypotheses.map(_.predict(features)).sum / weakHypotheses.size + } + } + + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + def predict(features: Vector): Double = { + combiningStrategy match { + case Sum => predictBySumming(features) + case Average => predictByAveraging(features) + case _ => throw new IllegalArgumentException( + s"WeightedEnsembleModel given unknown combining parameter: $combiningStrategy.") + } + } + + /** + * Predict values for the given data set. + * + * @param features RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x)) + + /** + * Print a summary of the model. + */ + override def toString: String = { + algo match { + case Classification => + s"WeightedEnsembleModel classifier with $numWeakHypotheses trees\n" + case Regression => + s"WeightedEnsembleModel regressor with $numWeakHypotheses trees\n" + case _ => throw new IllegalArgumentException( + s"WeightedEnsembleModel given unknown algo parameter: $algo.") + } + } + + /** + * Print the full model to a string. + */ + def toDebugString: String = { + val header = toString + "\n" + header + weakHypotheses.zipWithIndex.map { case (tree, treeIndex) => + s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) + }.fold("")(_ + _) + } + + /** + * Get number of trees in forest. + */ + def numWeakHypotheses: Int = weakHypotheses.size + + // TODO: Remove these helpers methods once class is generalized to support any base learning + // algorithms. + + /** + * Get total number of nodes, summed over all trees in the forest. + */ + def totalNumNodes: Int = weakHypotheses.map(tree => tree.numNodes).sum + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index b88e08bf148ae..9353351af72a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD -import org.apache.spark.util.random.BernoulliSampler +import org.apache.spark.util.random.BernoulliCellSampler import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.storage.StorageLevel @@ -244,7 +244,7 @@ object MLUtils { def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { val numFoldsF = numFolds.toFloat (1 to numFolds).map { fold => - val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, + val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, complement = false) val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed) val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala new file mode 100644 index 0000000000000..850c9fce507cd --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.random.XORShiftRandom + +class StreamingKMeansSuite extends FunSuite with TestSuiteBase { + + override def maxWaitTimeMillis = 30000 + + test("accuracy for single center and equivalence to grand average") { + // set parameters + val numBatches = 10 + val numPoints = 50 + val k = 1 + val d = 5 + val r = 0.1 + + // create model with one cluster + val model = new StreamingKMeans() + .setK(1) + .setDecayFactor(1.0) + .setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0)), Array(0.0)) + + // generate random data for k-means + val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) + + // setup and run the model training + val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + model.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // estimated center should be close to true center + assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1) + + // estimated center from streaming should exactly match the arithmetic mean of all data points + // because the decay factor is set to 1.0 + val grandMean = + input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble + assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5) + } + + test("accuracy for two centers") { + val numBatches = 10 + val numPoints = 5 + val k = 2 + val d = 5 + val r = 0.1 + + // create model with two clusters + val kMeans = new StreamingKMeans() + .setK(2) + .setHalfLife(2, "batches") + .setInitialCenters( + Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1), + Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1)), + Array(5.0, 5.0)) + + // generate random data for k-means + val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) + + // setup and run the model training + val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + kMeans.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // check that estimated centers are close to true centers + // NOTE exact assignment depends on the initialization! + assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1) + assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1) + } + + test("detecting dying clusters") { + val numBatches = 10 + val numPoints = 5 + val k = 1 + val d = 1 + val r = 1.0 + + // create model with two clusters + val kMeans = new StreamingKMeans() + .setK(2) + .setHalfLife(0.5, "points") + .setInitialCenters( + Array(Vectors.dense(0.0), Vectors.dense(1000.0)), + Array(1.0, 1.0)) + + // new data are all around the first cluster 0.0 + val (input, _) = + StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0))) + + // setup and run the model training + val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + kMeans.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // check that estimated centers are close to true centers + // NOTE exact assignment depends on the initialization! + val model = kMeans.latestModel() + val c0 = model.clusterCenters(0)(0) + val c1 = model.clusterCenters(1)(0) + + assert(c0 * c1 < 0.0, "should have one positive center and one negative center") + // 0.8 is the mean of half-normal distribution + assert(math.abs(c0) ~== 0.8 absTol 0.6) + assert(math.abs(c1) ~== 0.8 absTol 0.6) + } + + def StreamingKMeansDataGenerator( + numPoints: Int, + numBatches: Int, + k: Int, + d: Int, + r: Double, + seed: Int, + initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[Vector]], Array[Vector]) = { + val rand = new XORShiftRandom(seed) + val centers = initCenters match { + case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian()))) + case _ => initCenters + } + val data = (0 until numBatches).map { i => + (0 until numPoints).map { idx => + val center = centers(idx % k) + Vectors.dense(Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r)) + } + } + (data, centers) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala new file mode 100644 index 0000000000000..342baa0274e9c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.rdd.RDD + +class MultilabelMetricsSuite extends FunSuite with LocalSparkContext { + test("Multilabel evaluation metrics") { + /* + * Documents true labels (5x class0, 3x class1, 4x class2): + * doc 0 - predict 0, 1 - class 0, 2 + * doc 1 - predict 0, 2 - class 0, 1 + * doc 2 - predict none - class 0 + * doc 3 - predict 2 - class 2 + * doc 4 - predict 2, 0 - class 2, 0 + * doc 5 - predict 0, 1, 2 - class 0, 1 + * doc 6 - predict 1 - class 1, 2 + * + * predicted classes + * class 0 - doc 0, 1, 4, 5 (total 4) + * class 1 - doc 0, 5, 6 (total 3) + * class 2 - doc 1, 3, 4, 5 (total 4) + * + * true classes + * class 0 - doc 0, 1, 2, 4, 5 (total 5) + * class 1 - doc 1, 5, 6 (total 3) + * class 2 - doc 0, 3, 4, 6 (total 4) + * + */ + val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array(), Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + val metrics = new MultilabelMetrics(scoreAndLabels) + val delta = 0.00001 + val precision0 = 4.0 / (4 + 0) + val precision1 = 2.0 / (2 + 1) + val precision2 = 2.0 / (2 + 2) + val recall0 = 4.0 / (4 + 1) + val recall1 = 2.0 / (2 + 1) + val recall2 = 2.0 / (2 + 2) + val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) + val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) + val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) + val sumTp = 4 + 2 + 2 + assert(sumTp == (1 + 1 + 0 + 1 + 2 + 2 + 1)) + val microPrecisionClass = sumTp.toDouble / (4 + 0 + 2 + 1 + 2 + 2) + val microRecallClass = sumTp.toDouble / (4 + 1 + 2 + 1 + 2 + 2) + val microF1MeasureClass = 2.0 * sumTp.toDouble / + (2 * sumTp.toDouble + (1 + 1 + 2) + (0 + 1 + 2)) + val macroPrecisionDoc = 1.0 / 7 * + (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0) + val macroRecallDoc = 1.0 / 7 * + (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2) + val macroF1MeasureDoc = (1.0 / 7) * + 2 * ( 1.0 / (2 + 2) + 1.0 / (2 + 2) + 0 + 1.0 / (1 + 1) + + 2.0 / (2 + 2) + 2.0 / (3 + 2) + 1.0 / (1 + 2) ) + val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1) + val strictAccuracy = 2.0 / 7 + val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 /3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2) + assert(math.abs(metrics.precision(0.0) - precision0) < delta) + assert(math.abs(metrics.precision(1.0) - precision1) < delta) + assert(math.abs(metrics.precision(2.0) - precision2) < delta) + assert(math.abs(metrics.recall(0.0) - recall0) < delta) + assert(math.abs(metrics.recall(1.0) - recall1) < delta) + assert(math.abs(metrics.recall(2.0) - recall2) < delta) + assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta) + assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta) + assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta) + assert(math.abs(metrics.microPrecision - microPrecisionClass) < delta) + assert(math.abs(metrics.microRecall - microRecallClass) < delta) + assert(math.abs(metrics.microF1Measure - microF1MeasureClass) < delta) + assert(math.abs(metrics.precision - macroPrecisionDoc) < delta) + assert(math.abs(metrics.recall - macroRecallDoc) < delta) + assert(math.abs(metrics.f1Measure - macroF1MeasureDoc) < delta) + assert(math.abs(metrics.hammingLoss - hammingLoss) < delta) + assert(math.abs(metrics.subsetAccuracy - strictAccuracy) < delta) + assert(math.abs(metrics.accuracy - accuracy) < delta) + assert(metrics.labels.sameElements(Array(0.0, 1.0, 2.0))) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala new file mode 100644 index 0000000000000..5396d7b2b74fa --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class RegressionMetricsSuite extends FunSuite with LocalSparkContext { + + test("regression metrics") { + val predictionAndObservations = sc.parallelize( + Seq((2.5,3.0),(0.0,-0.5),(2.0,2.0),(8.0,7.0)), 2) + val metrics = new RegressionMetrics(predictionAndObservations) + assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5, + "root mean squared error mismatch") + assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch") + } + + test("regression metrics with complete fitting") { + val predictionAndObservations = sc.parallelize( + Seq((3.0,3.0),(0.0,0.0),(2.0,2.0),(8.0,8.0)), 2) + val metrics = new RegressionMetrics(predictionAndObservations) + assert(metrics.explainedVariance ~== 1.0 absTol 1E-5, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.0 absTol 1E-5, + "root mean squared error mismatch") + assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index cd651fe2d2ddf..93a84fe07b32a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -155,4 +155,15 @@ class VectorsSuite extends FunSuite { throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.") } } + + test("VectorUDT") { + val dv0 = Vectors.dense(Array.empty[Double]) + val dv1 = Vectors.dense(1.0, 2.0) + val sv0 = Vectors.sparse(2, Array.empty, Array.empty) + val sv1 = Vectors.sparse(2, Array(1), Array(2.0)) + val udt = new VectorUDT() + for (v <- Seq(dv0, dv1, sv0, sv1)) { + assert(v === udt.deserialize(udt.serialize(v))) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 27a19f793242b..4ef67a40b9f49 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -42,9 +42,9 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext { val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) val rdd = sc.parallelize(data, data.length).flatMap(s => s) assert(rdd.partitions.size === data.length) - val sliding = rdd.sliding(3) - val expected = data.flatMap(x => x).sliding(3).toList - assert(sliding.collect().toList === expected) + val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq) + val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) + assert(sliding === expected) } test("treeAggregate") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 8fc5e111bbc17..c579cb58549f5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -493,7 +493,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(rootNode1.rightNode.nonEmpty) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) // Single group second level tree construction. val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) @@ -786,7 +786,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) val topNode = Node.emptyNode(nodeIndex = 1) assert(topNode.predict.predict === Double.MinValue) @@ -829,7 +829,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) val topNode = Node.emptyNode(nodeIndex = 1) assert(topNode.predict.predict === Double.MinValue) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala new file mode 100644 index 0000000000000..effb7b8259ffb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.util.StatCounter + +import scala.collection.mutable + +object EnsembleTestHelper { + + /** + * Aggregates all values in data, and tests whether the empirical mean and stddev are within + * epsilon of the expected values. + * @param data Every element of the data should be an i.i.d. sample from some distribution. + */ + def testRandomArrays( + data: Array[Array[Double]], + numCols: Int, + expectedMean: Double, + expectedStddev: Double, + epsilon: Double) { + val values = new mutable.ArrayBuffer[Double]() + data.foreach { row => + assert(row.size == numCols) + values ++= row + } + val stats = new StatCounter(values) + assert(math.abs(stats.mean - expectedMean) < epsilon) + assert(math.abs(stats.stdev - expectedStddev) < epsilon) + } + + def validateClassifier( + model: WeightedEnsembleModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map(x => model.predict(x.features)) + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") + } + + def validateRegressor( + model: WeightedEnsembleModel, + input: Seq[LabeledPoint], + requiredMSE: Double) { + val predictions = input.map(x => model.predict(x.features)) + val squaredError = predictions.zip(input).map { case (prediction, expected) => + val err = prediction - expected.label + err * err + }.sum + val mse = squaredError / input.length + assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") + } + + def generateOrderedLabeledPoints(numFeatures: Int, numInstances: Int): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](numInstances) + for (i <- 0 until numInstances) { + val label = if (i < numInstances / 10) { + 0.0 + } else if (i < numInstances / 2) { + 1.0 + } else if (i < numInstances * 0.9) { + 0.0 + } else { + 1.0 + } + val features = Array.fill[Double](numFeatures)(i.toDouble) + arr(i) = new LabeledPoint(label, Vectors.dense(features)) + } + arr + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala new file mode 100644 index 0000000000000..99a02eda60baf --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} +import org.apache.spark.mllib.tree.impurity.Variance +import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss} + +import org.apache.spark.mllib.util.LocalSparkContext + +/** + * Test suite for [[GradientBoosting]]. + */ +class GradientBoostingSuite extends FunSuite with LocalSparkContext { + + test("Regression with continuous features: SquaredError") { + GradientBoostingSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + subsamplingRate = subsamplingRate) + + val dt = DecisionTree.train(remappedInput, treeStrategy) + + val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, + learningRate, 1, treeStrategy) + + val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) + assert(gbt.weakHypotheses.size === numIterations) + val gbtTree = gbt.weakHypotheses(0) + + EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) + + // Make sure trees are the same. + assert(gbtTree.toString == dt.toString) + } + } + + test("Regression with continuous features: Absolute Error") { + GradientBoostingSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + subsamplingRate = subsamplingRate) + + val dt = DecisionTree.train(remappedInput, treeStrategy) + + val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, + learningRate, numClassesForClassification = 2, treeStrategy) + + val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) + assert(gbt.weakHypotheses.size === numIterations) + val gbtTree = gbt.weakHypotheses(0) + + EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) + + // Make sure trees are the same. + assert(gbtTree.toString == dt.toString) + } + } + + test("Binary classification with continuous features: Log Loss") { + GradientBoostingSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + subsamplingRate = subsamplingRate) + + val dt = DecisionTree.train(remappedInput, treeStrategy) + + val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss, + learningRate, numClassesForClassification = 2, treeStrategy) + + val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy) + assert(gbt.weakHypotheses.size === numIterations) + val gbtTree = gbt.weakHypotheses(0) + + EnsembleTestHelper.validateClassifier(gbt, arr, 0.9) + + // Make sure trees are the same. + assert(gbtTree.toString == dt.toString) + } + } + +} + +object GradientBoostingSuite { + + // Combinations for estimators, learning rates and subsamplingRate + val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75)) + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index d3eff59aa0409..73c4393c3581a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -25,100 +25,91 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata} +import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Gini, Variance} -import org.apache.spark.mllib.tree.model.{Node, RandomForestModel} +import org.apache.spark.mllib.tree.model.Node import org.apache.spark.mllib.util.LocalSparkContext -import org.apache.spark.util.StatCounter /** * Test suite for [[RandomForest]]. */ class RandomForestSuite extends FunSuite with LocalSparkContext { - - test("BaggedPoint RDD: without subsampling") { - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1) - val rdd = sc.parallelize(arr) - val baggedRDD = BaggedPoint.convertToBaggedRDDWithoutSampling(rdd) - baggedRDD.collect().foreach { baggedPoint => - assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) - } - } - - test("BaggedPoint RDD: with subsampling") { - val numSubsamples = 100 - val (expectedMean, expectedStddev) = (1.0, 1.0) - - val seeds = Array(123, 5354, 230, 349867, 23987) - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1) + def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) - seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, numSubsamples, seed = seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() - RandomForestSuite.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, - expectedStddev, epsilon = 0.01) - } - } - - test("Binary classification with continuous features:" + - " comparing DecisionTree vs. RandomForest(numTrees = 1)") { - - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50) - val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] val numTrees = 1 - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) - val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) - assert(rf.trees.size === 1) - val rfTree = rf.trees(0) + assert(rf.weakHypotheses.size === 1) + val rfTree = rf.weakHypotheses(0) val dt = DecisionTree.train(rdd, strategy) - RandomForestSuite.validateClassifier(rf, arr, 0.9) + EnsembleTestHelper.validateClassifier(rf, arr, 0.9) DecisionTreeSuite.validateClassifier(dt, arr, 0.9) // Make sure trees are the same. assert(rfTree.toString == dt.toString) } - test("Regression with continuous features:" + + test("Binary classification with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + binaryClassificationTestWithContinuousFeatures(strategy) + } + + test("Binary classification with continuous features and node Id cache :" + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) + binaryClassificationTestWithContinuousFeatures(strategy) + } - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50) + def regressionTestWithContinuousFeatures(strategy: Strategy) { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] val numTrees = 1 - val strategy = new Strategy(algo = Regression, impurity = Variance, - maxDepth = 2, maxBins = 10, numClassesForClassification = 2, - categoricalFeaturesInfo = categoricalFeaturesInfo) - val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) - assert(rf.trees.size === 1) - val rfTree = rf.trees(0) + assert(rf.weakHypotheses.size === 1) + val rfTree = rf.weakHypotheses(0) val dt = DecisionTree.train(rdd, strategy) - RandomForestSuite.validateRegressor(rf, arr, 0.01) + EnsembleTestHelper.validateRegressor(rf, arr, 0.01) DecisionTreeSuite.validateRegressor(dt, arr, 0.01) // Make sure trees are the same. assert(rfTree.toString == dt.toString) } - test("Binary classification with continuous features: subsampling features") { - val numFeatures = 50 - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures) - val rdd = sc.parallelize(arr) + test("Regression with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Regression, impurity = Variance, + maxDepth = 2, maxBins = 10, numClassesForClassification = 2, + categoricalFeaturesInfo = categoricalFeaturesInfo) + regressionTestWithContinuousFeatures(strategy) + } - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + test("Regression with continuous features and node Id cache :" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Regression, impurity = Variance, + maxDepth = 2, maxBins = 10, numClassesForClassification = 2, + categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) + regressionTestWithContinuousFeatures(strategy) + } + + def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: Strategy) { + val numFeatures = 50 + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000) + val rdd = sc.parallelize(arr) // Select feature subset for top nodes. Return true if OK. def checkFeatureSubsetStrategy( @@ -174,6 +165,20 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) } + test("Binary classification with continuous features: subsampling features") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) + } + + test("Binary classification with continuous features and node Id cache: subsampling features") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) + binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) + } + test("alternating categorical and continuous features with multiclass labels to test indexing") { val arr = new Array[LabeledPoint](4) arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)) @@ -187,77 +192,8 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, featureSubsetStrategy = "sqrt", seed = 12345) - RandomForestSuite.validateClassifier(model, arr, 0.0) + EnsembleTestHelper.validateClassifier(model, arr, 1.0) } - } -object RandomForestSuite { - - /** - * Aggregates all values in data, and tests whether the empirical mean and stddev are within - * epsilon of the expected values. - * @param data Every element of the data should be an i.i.d. sample from some distribution. - */ - def testRandomArrays( - data: Array[Array[Double]], - numCols: Int, - expectedMean: Double, - expectedStddev: Double, - epsilon: Double) { - val values = new mutable.ArrayBuffer[Double]() - data.foreach { row => - assert(row.size == numCols) - values ++= row - } - val stats = new StatCounter(values) - assert(math.abs(stats.mean - expectedMean) < epsilon) - assert(math.abs(stats.stdev - expectedStddev) < epsilon) - } - - def validateClassifier( - model: RandomForestModel, - input: Seq[LabeledPoint], - requiredAccuracy: Double) { - val predictions = input.map(x => model.predict(x.features)) - val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => - prediction != expected.label - } - val accuracy = (input.length - numOffPredictions).toDouble / input.length - assert(accuracy >= requiredAccuracy, - s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") - } - def validateRegressor( - model: RandomForestModel, - input: Seq[LabeledPoint], - requiredMSE: Double) { - val predictions = input.map(x => model.predict(x.features)) - val squaredError = predictions.zip(input).map { case (prediction, expected) => - val err = prediction - expected.label - err * err - }.sum - val mse = squaredError / input.length - assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") - } - - def generateOrderedLabeledPoints(numFeatures: Int): Array[LabeledPoint] = { - val numInstances = 1000 - val arr = new Array[LabeledPoint](numInstances) - for (i <- 0 until numInstances) { - val label = if (i < numInstances / 10) { - 0.0 - } else if (i < numInstances / 2) { - 1.0 - } else if (i < numInstances * 0.9) { - 0.0 - } else { - 1.0 - } - val features = Array.fill[Double](numFeatures)(i.toDouble) - arr(i) = new LabeledPoint(label, Vectors.dense(features)) - } - arr - } - -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala new file mode 100644 index 0000000000000..5cb433232e714 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.tree.EnsembleTestHelper +import org.apache.spark.mllib.util.LocalSparkContext + +/** + * Test suite for [[BaggedPoint]]. + */ +class BaggedPointSuite extends FunSuite with LocalSparkContext { + + test("BaggedPoint RDD: without subsampling") { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42) + baggedRDD.collect().foreach { baggedPoint => + assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) + } + } + + test("BaggedPoint RDD: with subsampling with replacement (fraction = 1.0)") { + val numSubsamples = 100 + val (expectedMean, expectedStddev) = (1.0, 1.0) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling with replacement (fraction = 0.5)") { + val numSubsamples = 100 + val subsample = 0.5 + val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample)) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling without replacement (fraction = 1.0)") { + val numSubsamples = 100 + val (expectedMean, expectedStddev) = (1.0, 0) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling without replacement (fraction = 0.5)") { + val numSubsamples = 100 + val subsample = 0.5 + val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample * (1 - subsample))) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } +} diff --git a/network/common/pom.xml b/network/common/pom.xml index e3b7e328701b4..8b24ebf1ba1f2 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -27,12 +27,12 @@ org.apache.spark - network + spark-network-common_2.10 jar - Shuffle Streaming Service + Spark Project Networking http://spark.apache.org/ - network + network-common @@ -50,6 +50,7 @@ com.google.guava guava + 11.0.2 provided @@ -59,6 +60,11 @@ junit test + + com.novocode + junit-interface + test + log4j log4j @@ -69,25 +75,36 @@ mockito-all test + + org.scalatest + scalatest_${scala.binary.version} + test + - - target/java/classes - target/java/test-classes + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + org.apache.maven.plugins - maven-surefire-plugin - 2.17 - - false - - **/Test*.java - **/*Test.java - **/*Suite.java - - + maven-jar-plugin + 2.2 + + + + test-jar + + + + test-jar-on-test-compile + test-compile + + test-jar + + + diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 854aa6685f85f..5bc6e5a2418a9 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -17,12 +17,16 @@ package org.apache.spark.network; +import java.util.List; + +import com.google.common.collect.Lists; import io.netty.channel.Channel; import io.netty.channel.socket.SocketChannel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.MessageDecoder; @@ -52,26 +56,39 @@ public class TransportContext { private final Logger logger = LoggerFactory.getLogger(TransportContext.class); private final TransportConf conf; - private final StreamManager streamManager; private final RpcHandler rpcHandler; private final MessageEncoder encoder; private final MessageDecoder decoder; - public TransportContext(TransportConf conf, StreamManager streamManager, RpcHandler rpcHandler) { + public TransportContext(TransportConf conf, RpcHandler rpcHandler) { this.conf = conf; - this.streamManager = streamManager; this.rpcHandler = rpcHandler; this.encoder = new MessageEncoder(); this.decoder = new MessageDecoder(); } + /** + * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning + * a new Client. Bootstraps will be executed synchronously, and must run successfully in order + * to create a Client. + */ + public TransportClientFactory createClientFactory(List bootstraps) { + return new TransportClientFactory(this, bootstraps); + } + public TransportClientFactory createClientFactory() { - return new TransportClientFactory(this); + return createClientFactory(Lists.newArrayList()); + } + + /** Create a server which will attempt to bind to a specific port. */ + public TransportServer createServer(int port) { + return new TransportServer(this, port); } + /** Creates a new server, binding to any available ephemeral port. */ public TransportServer createServer() { - return new TransportServer(this); + return new TransportServer(this, 0); } /** @@ -109,7 +126,7 @@ private TransportChannelHandler createChannelHandler(Channel channel) { TransportResponseHandler responseHandler = new TransportResponseHandler(channel); TransportClient client = new TransportClient(channel, responseHandler); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, - streamManager, rpcHandler); + rpcHandler); return new TransportChannelHandler(client, responseHandler, requestHandler); } diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 89ed79bc63903..844eff4f4c701 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -30,24 +30,20 @@ import io.netty.channel.DefaultFileRegion; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.network.util.TransportConf; /** * A {@link ManagedBuffer} backed by a segment in a file. */ public final class FileSegmentManagedBuffer extends ManagedBuffer { - - /** - * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). - * Avoid unless there's a good reason not to. - */ - // TODO: Make this configurable - private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; - + private final TransportConf conf; private final File file; private final long offset; private final long length; - public FileSegmentManagedBuffer(File file, long offset, long length) { + public FileSegmentManagedBuffer(TransportConf conf, File file, long offset, long length) { + this.conf = conf; this.file = file; this.offset = offset; this.length = length; @@ -64,7 +60,7 @@ public ByteBuffer nioByteBuffer() throws IOException { try { channel = new RandomAccessFile(file, "r").getChannel(); // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. - if (length < MIN_MEMORY_MAP_BYTES) { + if (length < conf.memoryMapBytes()) { ByteBuffer buf = ByteBuffer.allocate((int) length); channel.position(offset); while (buf.remaining() != 0) { @@ -101,7 +97,7 @@ public InputStream createInputStream() throws IOException { try { is = new FileInputStream(file); ByteStreams.skipFully(is, offset); - return ByteStreams.limit(is, length); + return new LimitedInputStream(is, length); } catch (IOException e) { try { if (is != null) { @@ -133,8 +129,12 @@ public ManagedBuffer release() { @Override public Object convertToNetty() throws IOException { - FileChannel fileChannel = new FileInputStream(file).getChannel(); - return new DefaultFileRegion(fileChannel, offset, length); + if (conf.lazyFileDescriptor()) { + return new LazyFileRegion(file, offset, length); + } else { + FileChannel fileChannel = new FileInputStream(file).getChannel(); + return new DefaultFileRegion(fileChannel, offset, length); + } } public File getFile() { return file; } diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java b/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java new file mode 100644 index 0000000000000..81bc8ec40fc82 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.FileInputStream; +import java.io.File; +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; + +import com.google.common.base.Objects; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; + +import org.apache.spark.network.util.JavaUtils; + +/** + * A FileRegion implementation that only creates the file descriptor when the region is being + * transferred. This cannot be used with Epoll because there is no native support for it. + * + * This is mostly copied from DefaultFileRegion implementation in Netty. In the future, we + * should push this into Netty so the native Epoll transport can support this feature. + */ +public final class LazyFileRegion extends AbstractReferenceCounted implements FileRegion { + + private final File file; + private final long position; + private final long count; + + private FileChannel channel; + + private long numBytesTransferred = 0L; + + /** + * @param file file to transfer. + * @param position start position for the transfer. + * @param count number of bytes to transfer starting from position. + */ + public LazyFileRegion(File file, long position, long count) { + this.file = file; + this.position = position; + this.count = count; + } + + @Override + protected void deallocate() { + JavaUtils.closeQuietly(channel); + } + + @Override + public long position() { + return position; + } + + @Override + public long transfered() { + return numBytesTransferred; + } + + @Override + public long count() { + return count; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + if (channel == null) { + channel = new FileInputStream(file).getChannel(); + } + + long count = this.count - position; + if (count < 0 || position < 0) { + throw new IllegalArgumentException( + "position out of range: " + position + " (expected: 0 - " + (count - 1) + ')'); + } + + if (count == 0) { + return 0L; + } + + long written = channel.transferTo(this.position + position, count, target); + if (written > 0) { + numBytesTransferred += written; + } + return written; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("file", file) + .add("position", position) + .add("count", count) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index b1732fcde21f1..4e944114e8176 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -18,10 +18,15 @@ package org.apache.spark.network.client; import java.io.Closeable; +import java.io.IOException; import java.util.UUID; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import com.google.common.base.Objects; import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.util.concurrent.SettableFuture; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; @@ -113,8 +118,12 @@ public void operationComplete(ChannelFuture future) throws Exception { serverAddr, future.cause()); logger.error(errorMsg, future.cause()); handler.removeFetchRequest(streamChunkId); - callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause())); channel.close(); + try { + callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } } } }); @@ -129,7 +138,7 @@ public void sendRpc(byte[] message, final RpcResponseCallback callback) { final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); - final long requestId = UUID.randomUUID().getLeastSignificantBits(); + final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); handler.addRpcRequest(requestId, callback); channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( @@ -144,16 +153,56 @@ public void operationComplete(ChannelFuture future) throws Exception { serverAddr, future.cause()); logger.error(errorMsg, future.cause()); handler.removeRpcRequest(requestId); - callback.onFailure(new RuntimeException(errorMsg, future.cause())); channel.close(); + try { + callback.onFailure(new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } } } }); } + /** + * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to + * a specified timeout for a response. + */ + public byte[] sendRpcSync(byte[] message, long timeoutMs) { + final SettableFuture result = SettableFuture.create(); + + sendRpc(message, new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + result.set(response); + } + + @Override + public void onFailure(Throwable e) { + result.setException(e); + } + }); + + try { + return result.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (ExecutionException e) { + throw Throwables.propagate(e.getCause()); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("remoteAdress", channel.remoteAddress()) + .add("isActive", isActive()) + .toString(); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java new file mode 100644 index 0000000000000..65e8020e34121 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * A bootstrap which is executed on a TransportClient before it is returned to the user. + * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per- + * connection basis. + * + * Since connections (and TransportClients) are reused as much as possible, it is generally + * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with + * the JVM itself. + */ +public interface TransportClientBootstrap { + /** Performs the bootstrapping operation, throwing an exception on failure. */ + public void doBootstrap(TransportClient client) throws RuntimeException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 10eb9ef7a025f..397d3a8455c86 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -18,13 +18,17 @@ package org.apache.spark.network.client; import java.io.Closeable; +import java.io.IOException; import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.List; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.Lists; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.Channel; @@ -47,22 +51,29 @@ * Factory for creating {@link TransportClient}s by using createClient. * * The factory maintains a connection pool to other hosts and should return the same - * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for - * all {@link TransportClient}s. + * TransportClient for the same remote host. It also shares a single worker thread pool for + * all TransportClients. + * + * TransportClients will be reused whenever possible. Prior to completing the creation of a new + * TransportClient, all given {@link TransportClientBootstrap}s will be run. */ public class TransportClientFactory implements Closeable { private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class); private final TransportContext context; private final TransportConf conf; + private final List clientBootstraps; private final ConcurrentHashMap connectionPool; private final Class socketChannelClass; - private final EventLoopGroup workerGroup; + private EventLoopGroup workerGroup; - public TransportClientFactory(TransportContext context) { - this.context = context; + public TransportClientFactory( + TransportContext context, + List clientBootstraps) { + this.context = Preconditions.checkNotNull(context); this.conf = context.getConf(); + this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); this.connectionPool = new ConcurrentHashMap(); IOMode ioMode = IOMode.valueOf(conf.ioMode()); @@ -72,21 +83,28 @@ public TransportClientFactory(TransportContext context) { } /** - * Create a new BlockFetchingClient connecting to the given remote host / port. + * Create a new {@link TransportClient} connecting to the given remote host / port. This will + * reuse TransportClients if they are still active and are for the same remote address. Prior + * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s + * that are registered with this factory. * - * This blocks until a connection is successfully established. + * This blocks until a connection is successfully established and fully bootstrapped. * * Concurrency: This method is safe to call from multiple threads. */ - public TransportClient createClient(String remoteHost, int remotePort) throws TimeoutException { + public TransportClient createClient(String remoteHost, int remotePort) throws IOException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); TransportClient cachedClient = connectionPool.get(address); - if (cachedClient != null && cachedClient.isActive()) { - return cachedClient; - } else if (cachedClient != null) { - connectionPool.remove(address, cachedClient); // Remove inactive clients. + if (cachedClient != null) { + if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); + return cachedClient; + } else { + logger.info("Found inactive connection to {}, closing it.", address); + connectionPool.remove(address, cachedClient); // Remove inactive clients. + } } logger.debug("Creating new connection to " + address); @@ -102,33 +120,55 @@ public TransportClient createClient(String remoteHost, int remotePort) throws Ti // Use pooled buffers to reduce temporary buffer allocation bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()); - final AtomicReference client = new AtomicReference(); + final AtomicReference clientRef = new AtomicReference(); bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { TransportChannelHandler clientHandler = context.initializePipeline(ch); - client.set(clientHandler.getClient()); + clientRef.set(clientHandler.getClient()); } }); // Connect to the remote server + long preConnect = System.currentTimeMillis(); ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { - throw new TimeoutException( + throw new IOException( String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } else if (cf.cause() != null) { - throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); + throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); + } + + TransportClient client = clientRef.get(); + assert client != null : "Channel future completed successfully with null client"; + + // Execute any client bootstraps synchronously before marking the Client as successful. + long preBootstrap = System.currentTimeMillis(); + logger.debug("Connection to {} successful, running bootstraps...", address); + try { + for (TransportClientBootstrap clientBootstrap : clientBootstraps) { + clientBootstrap.doBootstrap(client); + } + } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala + long bootstrapTime = System.currentTimeMillis() - preBootstrap; + logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e); + client.close(); + throw Throwables.propagate(e); } + long postBootstrap = System.currentTimeMillis(); - // Successful connection - assert client.get() != null : "Channel future completed successfully with null client"; - TransportClient oldClient = connectionPool.putIfAbsent(address, client.get()); + // Successful connection & bootstrap -- in the event that two threads raced to create a client, + // use the first one that was put into the connectionPool and close the one we made here. + TransportClient oldClient = connectionPool.putIfAbsent(address, client); if (oldClient == null) { - return client.get(); + logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", + address, postBootstrap - preConnect, postBootstrap - preBootstrap); + return client; } else { - logger.debug("Two clients were created concurrently, second one will be disposed."); - client.get().close(); + logger.debug("Two clients were created concurrently after {} ms, second will be disposed.", + postBootstrap - preConnect); + client.close(); return oldClient; } } @@ -147,6 +187,7 @@ public void close() { if (workerGroup != null) { workerGroup.shutdownGracefully(); + workerGroup = null; } } @@ -158,7 +199,7 @@ public void close() { */ private PooledByteBufAllocator createPooledByteBufAllocator() { return new PooledByteBufAllocator( - PlatformDependent.directBufferPreferred(), + conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(), getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), getPrivateStaticField("DEFAULT_PAGE_SIZE"), diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index d8965590b34da..2044afb0d85db 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.client; +import java.io.IOException; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -94,7 +95,7 @@ public void channelUnregistered() { String remoteAddress = NettyUtils.getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", numOutstandingRequests(), remoteAddress); - failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed")); + failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed")); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index 152af98ced7ce..986957c1509fd 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -38,23 +38,19 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { @Override public int encodedLength() { - return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length; + return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString); } @Override public void encode(ByteBuf buf) { streamChunkId.encode(buf); - byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); - buf.writeInt(errorBytes.length); - buf.writeBytes(errorBytes); + Encoders.Strings.encode(buf, errorString); } public static ChunkFetchFailure decode(ByteBuf buf) { StreamChunkId streamChunkId = StreamChunkId.decode(buf); - int numErrorStringBytes = buf.readInt(); - byte[] errorBytes = new byte[numErrorStringBytes]; - buf.readBytes(errorBytes); - return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8)); + String errorString = Encoders.Strings.decode(buf); + return new ChunkFetchFailure(streamChunkId, errorString); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java new file mode 100644 index 0000000000000..873c694250942 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + + +import com.google.common.base.Charsets; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +/** Provides a canonical set of Encoders for simple types. */ +public class Encoders { + + /** Strings are encoded with their length followed by UTF-8 bytes. */ + public static class Strings { + public static int encodedLength(String s) { + return 4 + s.getBytes(Charsets.UTF_8).length; + } + + public static void encode(ByteBuf buf, String s) { + byte[] bytes = s.getBytes(Charsets.UTF_8); + buf.writeInt(bytes.length); + buf.writeBytes(bytes); + } + + public static String decode(ByteBuf buf) { + int length = buf.readInt(); + byte[] bytes = new byte[length]; + buf.readBytes(bytes); + return new String(bytes, Charsets.UTF_8); + } + } + + /** Byte arrays are encoded with their length followed by bytes. */ + public static class ByteArrays { + public static int encodedLength(byte[] arr) { + return 4 + arr.length; + } + + public static void encode(ByteBuf buf, byte[] arr) { + buf.writeInt(arr.length); + buf.writeBytes(arr); + } + + public static byte[] decode(ByteBuf buf) { + int length = buf.readInt(); + byte[] bytes = new byte[length]; + buf.readBytes(bytes); + return bytes; + } + } + + /** String arrays are encoded with the number of strings followed by per-String encoding. */ + public static class StringArrays { + public static int encodedLength(String[] strings) { + int totalLength = 4; + for (String s : strings) { + totalLength += Strings.encodedLength(s); + } + return totalLength; + } + + public static void encode(ByteBuf buf, String[] strings) { + buf.writeInt(strings.length); + for (String s : strings) { + Strings.encode(buf, s); + } + } + + public static String[] decode(ByteBuf buf) { + int numStrings = buf.readInt(); + String[] strings = new String[numStrings]; + for (int i = 0; i < strings.length; i ++) { + strings[i] = Strings.decode(buf); + } + return strings; + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 4cb8becc3ed22..91d1e8a538a77 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -66,7 +66,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) { // All messages have the frame length, message type, and message itself. int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); long frameLength = headerLength + bodyLength; - ByteBuf header = ctx.alloc().buffer(headerLength); + ByteBuf header = ctx.alloc().heapBuffer(headerLength); header.writeLong(frameLength); msgType.encode(header); in.encode(header); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index e239d4ffbd29c..ebd764eb5eb5f 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -36,23 +36,19 @@ public RpcFailure(long requestId, String errorString) { @Override public int encodedLength() { - return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length; + return 8 + Encoders.Strings.encodedLength(errorString); } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); - buf.writeInt(errorBytes.length); - buf.writeBytes(errorBytes); + Encoders.Strings.encode(buf, errorString); } public static RpcFailure decode(ByteBuf buf) { long requestId = buf.readLong(); - int numErrorStringBytes = buf.readInt(); - byte[] errorBytes = new byte[numErrorStringBytes]; - buf.readBytes(errorBytes); - return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8)); + String errorString = Encoders.Strings.decode(buf); + return new RpcFailure(requestId, errorString); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index 099e934ae018c..cdee0b0e0316b 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -44,21 +44,18 @@ public RpcRequest(long requestId, byte[] message) { @Override public int encodedLength() { - return 8 + 4 + message.length; + return 8 + Encoders.ByteArrays.encodedLength(message); } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - buf.writeInt(message.length); - buf.writeBytes(message); + Encoders.ByteArrays.encode(buf, message); } public static RpcRequest decode(ByteBuf buf) { long requestId = buf.readLong(); - int messageLen = buf.readInt(); - byte[] message = new byte[messageLen]; - buf.readBytes(message); + byte[] message = Encoders.ByteArrays.decode(buf); return new RpcRequest(requestId, message); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index ed479478325b6..0a62e09a8115c 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -36,20 +36,17 @@ public RpcResponse(long requestId, byte[] response) { public Type type() { return Type.RpcResponse; } @Override - public int encodedLength() { return 8 + 4 + response.length; } + public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - buf.writeInt(response.length); - buf.writeBytes(response); + Encoders.ByteArrays.encode(buf, response); } public static RpcResponse decode(ByteBuf buf) { long requestId = buf.readLong(); - int responseLen = buf.readInt(); - byte[] response = new byte[responseLen]; - buf.readBytes(response); + byte[] response = Encoders.ByteArrays.decode(buf); return new RpcResponse(requestId, response); } diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java similarity index 68% rename from network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java rename to network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 7aa37efc582e4..1502b7489e864 100644 --- a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -1,4 +1,6 @@ -package org.apache.spark.network;/* +package org.apache.spark.network.server; + +/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. @@ -17,12 +19,20 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.server.RpcHandler; -/** Test RpcHandler which always returns a zero-sized success. */ -public class NoOpRpcHandler implements RpcHandler { +/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */ +public class NoOpRpcHandler extends RpcHandler { + private final StreamManager streamManager; + + public NoOpRpcHandler() { + streamManager = new OneForOneStreamManager(); + } + @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - callback.onSuccess(new byte[0]); + throw new UnsupportedOperationException("Cannot handle messages"); } + + @Override + public StreamManager getStreamManager() { return streamManager; } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java similarity index 93% rename from network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java rename to network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index 9688705569634..731d48d4d9c6c 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -30,10 +30,10 @@ /** * StreamManager which allows registration of an Iterator, which are individually - * fetched as chunks by the client. + * fetched as chunks by the client. Each registered buffer is one chunk. */ -public class DefaultStreamManager extends StreamManager { - private final Logger logger = LoggerFactory.getLogger(DefaultStreamManager.class); +public class OneForOneStreamManager extends StreamManager { + private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class); private final AtomicLong nextStreamId; private final Map streams; @@ -51,7 +51,7 @@ private static class StreamState { } } - public DefaultStreamManager() { + public OneForOneStreamManager() { // For debugging purposes, start with a random stream id to help identifying different streams. // This does not need to be globally unique, only unique to this class. nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index f54a696b8ff79..2ba92a40f8b0a 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,16 +23,33 @@ /** * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. */ -public interface RpcHandler { +public abstract class RpcHandler { /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. * + * This method will not be called in parallel for a single TransportClient (i.e., channel). + * * @param client A channel client which enables the handler to make requests back to the sender - * of this RPC. + * of this RPC. This will always be the exact same object for a particular channel. * @param message The serialized bytes of the RPC. * @param callback Callback which should be invoked exactly once upon success or failure of the * RPC. */ - void receive(TransportClient client, byte[] message, RpcResponseCallback callback); + public abstract void receive( + TransportClient client, + byte[] message, + RpcResponseCallback callback); + + /** + * Returns the StreamManager which contains the state about which streams are currently being + * fetched by a TransportClient. + */ + public abstract StreamManager getStreamManager(); + + /** + * Invoked when the connection associated with the given client has been invalidated. + * No further requests will come from this client. + */ + public void connectionTerminated(TransportClient client) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 352f865935b11..1580180cc17e9 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -56,24 +56,23 @@ public class TransportRequestHandler extends MessageHandler { /** Client on the same channel allowing us to talk back to the requester. */ private final TransportClient reverseClient; - /** Returns each chunk part of a stream. */ - private final StreamManager streamManager; - /** Handles all RPC messages. */ private final RpcHandler rpcHandler; + /** Returns each chunk part of a stream. */ + private final StreamManager streamManager; + /** List of all stream ids that have been read on this handler, used for cleanup. */ private final Set streamIds; public TransportRequestHandler( Channel channel, TransportClient reverseClient, - StreamManager streamManager, RpcHandler rpcHandler) { this.channel = channel; this.reverseClient = reverseClient; - this.streamManager = streamManager; this.rpcHandler = rpcHandler; + this.streamManager = rpcHandler.getStreamManager(); this.streamIds = Sets.newHashSet(); } @@ -87,6 +86,7 @@ public void channelUnregistered() { for (long streamId : streamIds) { streamManager.connectionTerminated(streamId); } + rpcHandler.connectionTerminated(reverseClient); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 243070750d6e7..579676c2c3564 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -28,6 +28,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import io.netty.util.internal.PlatformDependent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,11 +50,12 @@ public class TransportServer implements Closeable { private ChannelFuture channelFuture; private int port = -1; - public TransportServer(TransportContext context) { + /** Creates a TransportServer that binds to the given port, or to any available if 0. */ + public TransportServer(TransportContext context, int portToBind) { this.context = context; this.conf = context.getConf(); - init(); + init(portToBind); } public int getPort() { @@ -63,18 +65,21 @@ public int getPort() { return port; } - private void init() { + private void init(int portToBind) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); EventLoopGroup bossGroup = - NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); + NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); EventLoopGroup workerGroup = bossGroup; + PooledByteBufAllocator allocator = new PooledByteBufAllocator( + conf.preferDirectBufs() && PlatformDependent.directBufferPreferred()); + bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) - .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT); + .option(ChannelOption.ALLOCATOR, allocator) + .childOption(ChannelOption.ALLOCATOR, allocator); if (conf.backLog() > 0) { bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog()); @@ -95,7 +100,7 @@ protected void initChannel(SocketChannel ch) throws Exception { } }); - channelFuture = bootstrap.bind(new InetSocketAddress(conf.serverPort())); + channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); channelFuture.syncUninterruptibly(); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); @@ -105,7 +110,7 @@ protected void initChannel(SocketChannel ch) throws Exception { @Override public void close() { if (channelFuture != null) { - // close is a local operation and should finish with milliseconds; timeout just to be safe + // close is a local operation and should finish within milliseconds; timeout just to be safe channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS); channelFuture = null; } diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 32ba3f5b07f7a..bf8a1fc42fc6d 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,22 +17,114 @@ package org.apache.spark.network.util; +import java.nio.ByteBuffer; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.Closeable; +import java.io.File; import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import com.google.common.base.Preconditions; import com.google.common.io.Closeables; +import com.google.common.base.Charsets; +import io.netty.buffer.Unpooled; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +/** + * General utilities available in the network package. Many of these are sourced from Spark's + * own Utils, just accessible within this package. + */ public class JavaUtils { private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); /** Closes the given object, ignoring IOExceptions. */ public static void closeQuietly(Closeable closeable) { try { - closeable.close(); + if (closeable != null) { + closeable.close(); + } } catch (IOException e) { logger.error("IOException should not have been thrown.", e); } } + + /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */ + public static int nonNegativeHash(Object obj) { + if (obj == null) { return 0; } + int hash = obj.hashCode(); + return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0; + } + + /** + * Convert the given string to a byte buffer. The resulting buffer can be + * converted back to the same string through {@link #bytesToString(ByteBuffer)}. + */ + public static ByteBuffer stringToBytes(String s) { + return Unpooled.wrappedBuffer(s.getBytes(Charsets.UTF_8)).nioBuffer(); + } + + /** + * Convert the given byte buffer to a string. The resulting string can be + * converted back to the same byte buffer through {@link #stringToBytes(String)}. + */ + public static String bytesToString(ByteBuffer b) { + return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8); + } + + /* + * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. + * Throws an exception if deletion is unsuccessful. + */ + public static void deleteRecursively(File file) throws IOException { + if (file == null) { return; } + + if (file.isDirectory() && !isSymlink(file)) { + IOException savedIOException = null; + for (File child : listFilesSafely(file)) { + try { + deleteRecursively(child); + } catch (IOException e) { + // In case of multiple exceptions, only last one will be thrown + savedIOException = e; + } + } + if (savedIOException != null) { + throw savedIOException; + } + } + + boolean deleted = file.delete(); + // Delete can also fail if the file simply did not exist. + if (!deleted && file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath()); + } + } + + private static File[] listFilesSafely(File file) throws IOException { + if (file.exists()) { + File[] files = file.listFiles(); + if (files == null) { + throw new IOException("Failed to list files for dir: " + file); + } + return files; + } else { + return new File[0]; + } + } + + private static boolean isSymlink(File file) throws IOException { + Preconditions.checkNotNull(file); + File fileInCanonicalDir = null; + if (file.getParent() == null) { + fileInCanonicalDir = file; + } else { + fileInCanonicalDir = new File(file.getParentFile().getCanonicalFile(), file.getName()); + } + return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java new file mode 100644 index 0000000000000..63ca43c046525 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; + +import com.google.common.base.Preconditions; + +/** + * Wraps a {@link InputStream}, limiting the number of bytes which can be read. + * + * This code is from Guava's 14.0 source code, because there is no compatible way to + * use this functionality in both a Guava 11 environment and a Guava >14 environment. + */ +public final class LimitedInputStream extends FilterInputStream { + private long left; + private long mark = -1; + + public LimitedInputStream(InputStream in, long limit) { + super(in); + Preconditions.checkNotNull(in); + Preconditions.checkArgument(limit >= 0, "limit must be non-negative"); + left = limit; + } + @Override public int available() throws IOException { + return (int) Math.min(in.available(), left); + } + // it's okay to mark even if mark isn't supported, as reset won't work + @Override public synchronized void mark(int readLimit) { + in.mark(readLimit); + mark = left; + } + @Override public int read() throws IOException { + if (left == 0) { + return -1; + } + int result = in.read(); + if (result != -1) { + --left; + } + return result; + } + @Override public int read(byte[] b, int off, int len) throws IOException { + if (left == 0) { + return -1; + } + len = (int) Math.min(len, left); + int result = in.read(b, off, len); + if (result != -1) { + left -= result; + } + return result; + } + @Override public synchronized void reset() throws IOException { + if (!in.markSupported()) { + throw new IOException("Mark not supported"); + } + if (mark == -1) { + throw new IOException("Mark not set"); + } + in.reset(); + left = mark; + } + @Override public long skip(long n) throws IOException { + n = Math.min(n, left); + long skipped = in.skip(n); + left -= skipped; + return skipped; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index b1872341198e0..2a7664fe89388 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -37,13 +37,17 @@ * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO. */ public class NettyUtils { - /** Creates a Netty EventLoopGroup based on the IOMode. */ - public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { - - ThreadFactory threadFactory = new ThreadFactoryBuilder() + /** Creates a new ThreadFactory which prefixes each thread with the given name. */ + public static ThreadFactory createThreadFactory(String threadPoolPrefix) { + return new ThreadFactoryBuilder() .setDaemon(true) - .setNameFormat(threadPrefix + "-%d") + .setNameFormat(threadPoolPrefix + "-%d") .build(); + } + + /** Creates a Netty EventLoopGroup based on the IOMode. */ + public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { + ThreadFactory threadFactory = createThreadFactory(threadPrefix); switch (mode) { case NIO: diff --git a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java similarity index 96% rename from network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java rename to network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java index f4e0a2426a3d2..5f20b70678d1e 100644 --- a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java +++ b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network; +package org.apache.spark.network.util; import java.util.NoSuchElementException; diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 80f65d98032da..621427d8cba5e 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -27,12 +27,14 @@ public TransportConf(ConfigProvider conf) { this.conf = conf; } - /** Port the server listens on. Default to a random port. */ - public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); } - /** IO mode: nio or epoll */ public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + /** If true, we will prefer allocating off-heap byte buffers within Netty. */ + public boolean preferDirectBufs() { + return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true); + } + /** Connect timeout in secs. Default 120 secs. */ public int connectionTimeoutMs() { return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000; @@ -58,4 +60,36 @@ public int connectionTimeoutMs() { /** Send buffer size (SO_SNDBUF). */ public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + + /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ + public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); } + + /** + * Max number of times we will try IO exceptions (such as connection timeouts) per request. + * If set to 0, we will not do any retries. + */ + public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); } + + /** + * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. + * Only relevant if maxIORetries > 0. + */ + public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); } + + /** + * Minimum size of a block that we should start using memory map rather than reading in through + * normal IO operations. This prevents Spark from memory mapping very small blocks. In general, + * memory mapping has high overhead for blocks close to or below the page size of the OS. + */ + public int memoryMapBytes() { + return conf.getInt("spark.storage.memoryMapThreshold", 2 * 1024 * 1024); + } + + /** + * Whether to initialize shuffle FileDescriptor lazily or not. If true, file descriptors are + * created only when data is going to be transferred. This can reduce the number of open files. + */ + public boolean lazyFileDescriptor() { + return conf.getBoolean("spark.shuffle.io.lazyFD", true); + } } diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 738dca9b6a9ee..dfb7740344ed0 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -41,10 +41,13 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class ChunkFetchIntegrationSuite { @@ -60,6 +63,8 @@ public class ChunkFetchIntegrationSuite { static ManagedBuffer bufferChunk; static ManagedBuffer fileChunk; + private TransportConf transportConf; + @BeforeClass public static void setUp() throws Exception { int bufSize = 100000; @@ -77,9 +82,10 @@ public static void setUp() throws Exception { new Random().nextBytes(fileContent); fp.write(fileContent); fp.close(); - fileChunk = new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); + streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { @@ -87,13 +93,24 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { if (chunkIndex == BUFFER_CHUNK_INDEX) { return new NioManagedBuffer(buf); } else if (chunkIndex == FILE_CHUNK_INDEX) { - return new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); + return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); } else { throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); } } }; - TransportContext context = new TransportContext(conf, streamManager, new NoOpRpcHandler()); + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + + @Override + public StreamManager getStreamManager() { + return streamManager; + } + }; + TransportContext context = new TransportContext(conf, handler); server = context.createServer(); clientFactory = context.createClientFactory(); } diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 9f216dd2d722d..64b457b4b3f01 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -35,9 +35,11 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.DefaultStreamManager; +import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { @@ -61,8 +63,11 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback throw new RuntimeException("Thrown: " + parts[1]); } } + + @Override + public StreamManager getStreamManager() { return new OneForOneStreamManager(); } }; - TransportContext context = new TransportContext(conf, new DefaultStreamManager(), rpcHandler); + TransportContext context = new TransportContext(conf, rpcHandler); server = context.createServer(); clientFactory = context.createClientFactory(); } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 3ef964616f0c5..822bef1d81b2a 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.io.IOException; import java.util.concurrent.TimeoutException; import org.junit.After; @@ -28,11 +29,11 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.DefaultStreamManager; +import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class TransportClientFactorySuite { @@ -44,9 +45,8 @@ public class TransportClientFactorySuite { @Before public void setUp() { conf = new TransportConf(new SystemPropertyConfigProvider()); - StreamManager streamManager = new DefaultStreamManager(); RpcHandler rpcHandler = new NoOpRpcHandler(); - context = new TransportContext(conf, streamManager, rpcHandler); + context = new TransportContext(conf, rpcHandler); server1 = context.createServer(); server2 = context.createServer(); } @@ -58,7 +58,7 @@ public void tearDown() { } @Test - public void createAndReuseBlockClients() throws TimeoutException { + public void createAndReuseBlockClients() throws IOException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); @@ -71,7 +71,7 @@ public void createAndReuseBlockClients() throws TimeoutException { } @Test - public void neverReturnInactiveClients() throws Exception { + public void neverReturnInactiveClients() throws IOException, InterruptedException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); c1.close(); @@ -89,7 +89,7 @@ public void neverReturnInactiveClients() throws Exception { } @Test - public void closeBlockClientsWithFactory() throws TimeoutException { + public void closeBlockClientsWithFactory() throws IOException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml new file mode 100644 index 0000000000000..27c8467687f10 --- /dev/null +++ b/network/shuffle/pom.xml @@ -0,0 +1,97 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.2.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-shuffle_2.10 + jar + Spark Project Shuffle Streaming Service + http://spark.apache.org/ + + network-shuffle + + + + + + org.apache.spark + spark-network-common_2.10 + ${project.version} + + + org.slf4j + slf4j-api + + + + + com.google.guava + guava + 11.0.2 + provided + + + + + org.apache.spark + spark-network-common_2.10 + ${project.version} + test-jar + test + + + junit + junit + test + + + com.novocode + junit-interface + test + + + log4j + log4j + test + + + org.mockito + mockito-all + test + + + org.scalatest + scalatest_${scala.binary.version} + test + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java new file mode 100644 index 0000000000000..7bc91e375371f --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The + * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId. + */ +public class SaslClientBootstrap implements TransportClientBootstrap { + private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); + + private final TransportConf conf; + private final String appId; + private final SecretKeyHolder secretKeyHolder; + + public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.appId = appId; + this.secretKeyHolder = secretKeyHolder; + } + + /** + * Performs SASL authentication by sending a token, and then proceeding with the SASL + * challenge-response tokens until we either successfully authenticate or throw an exception + * due to mismatch. + */ + @Override + public void doBootstrap(TransportClient client) { + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder); + try { + byte[] payload = saslClient.firstToken(); + + while (!saslClient.isComplete()) { + SaslMessage msg = new SaslMessage(appId, payload); + ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + msg.encode(buf); + + byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeout()); + payload = saslClient.response(response); + } + } finally { + try { + // Once authentication is complete, the server will trust all remaining communication. + saslClient.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL client", e); + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java new file mode 100644 index 0000000000000..cad76ab7aa54e --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** + * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged + * with the given appId. This appId allows a single SaslRpcHandler to multiplex different + * applications which may be using different sets of credentials. + */ +class SaslMessage implements Encodable { + + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xEA; + + public final String appId; + public final byte[] payload; + + public SaslMessage(String appId, byte[] payload) { + this.appId = appId; + this.payload = payload; + } + + @Override + public int encodedLength() { + return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.Strings.encode(buf, appId); + Encoders.ByteArrays.encode(buf, payload); + } + + public static SaslMessage decode(ByteBuf buf) { + if (buf.readByte() != TAG_BYTE) { + throw new IllegalStateException("Expected SaslMessage, received something else" + + " (maybe your client does not have SASL enabled?)"); + } + + String appId = Encoders.Strings.decode(buf); + byte[] payload = Encoders.ByteArrays.decode(buf); + return new SaslMessage(appId, payload); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java new file mode 100644 index 0000000000000..3777a18e33f78 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.concurrent.ConcurrentMap; + +import com.google.common.base.Charsets; +import com.google.common.collect.Maps; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; + +/** + * RPC Handler which performs SASL authentication before delegating to a child RPC handler. + * The delegate will only receive messages if the given connection has been successfully + * authenticated. A connection may be authenticated at most once. + * + * Note that the authentication process consists of multiple challenge-response pairs, each of + * which are individual RPCs. + */ +public class SaslRpcHandler extends RpcHandler { + private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); + + /** RpcHandler we will delegate to for authenticated connections. */ + private final RpcHandler delegate; + + /** Class which provides secret keys which are shared by server and client on a per-app basis. */ + private final SecretKeyHolder secretKeyHolder; + + /** Maps each channel to its SASL authentication state. */ + private final ConcurrentMap channelAuthenticationMap; + + public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { + this.delegate = delegate; + this.secretKeyHolder = secretKeyHolder; + this.channelAuthenticationMap = Maps.newConcurrentMap(); + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + SparkSaslServer saslServer = channelAuthenticationMap.get(client); + if (saslServer != null && saslServer.isComplete()) { + // Authentication complete, delegate to base handler. + delegate.receive(client, message, callback); + return; + } + + SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + + if (saslServer == null) { + // First message in the handshake, setup the necessary state. + saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder); + channelAuthenticationMap.put(client, saslServer); + } + + byte[] response = saslServer.response(saslMessage.payload); + if (saslServer.isComplete()) { + logger.debug("SASL authentication successful for channel {}", client); + } + callback.onSuccess(response); + } + + @Override + public StreamManager getStreamManager() { + return delegate.getStreamManager(); + } + + @Override + public void connectionTerminated(TransportClient client) { + SparkSaslServer saslServer = channelAuthenticationMap.remove(client); + if (saslServer != null) { + saslServer.dispose(); + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java new file mode 100644 index 0000000000000..81d5766794688 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +/** + * Interface for getting a secret key associated with some application. + */ +public interface SecretKeyHolder { + /** + * Gets an appropriate SASL User for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL user. + */ + String getSaslUser(String appId); + + /** + * Gets an appropriate SASL secret key for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key. + */ + String getSecretKey(String appId); +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java new file mode 100644 index 0000000000000..351c7930a900f --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.lang.Override; +import java.nio.ByteBuffer; +import java.util.concurrent.ConcurrentHashMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.util.JavaUtils; + +/** + * A class that manages shuffle secret used by the external shuffle service. + */ +public class ShuffleSecretManager implements SecretKeyHolder { + private final Logger logger = LoggerFactory.getLogger(ShuffleSecretManager.class); + private final ConcurrentHashMap shuffleSecretMap; + + // Spark user used for authenticating SASL connections + // Note that this must match the value in org.apache.spark.SecurityManager + private static final String SPARK_SASL_USER = "sparkSaslUser"; + + public ShuffleSecretManager() { + shuffleSecretMap = new ConcurrentHashMap(); + } + + /** + * Register an application with its secret. + * Executors need to first authenticate themselves with the same secret before + * fetching shuffle files written by other executors in this application. + */ + public void registerApp(String appId, String shuffleSecret) { + if (!shuffleSecretMap.contains(appId)) { + shuffleSecretMap.put(appId, shuffleSecret); + logger.info("Registered shuffle secret for application {}", appId); + } else { + logger.debug("Application {} already registered", appId); + } + } + + /** + * Register an application with its secret specified as a byte buffer. + */ + public void registerApp(String appId, ByteBuffer shuffleSecret) { + registerApp(appId, JavaUtils.bytesToString(shuffleSecret)); + } + + /** + * Unregister an application along with its secret. + * This is called when the application terminates. + */ + public void unregisterApp(String appId) { + if (shuffleSecretMap.contains(appId)) { + shuffleSecretMap.remove(appId); + logger.info("Unregistered shuffle secret for application {}", appId); + } else { + logger.warn("Attempted to unregister application {} when it is not registered", appId); + } + } + + /** + * Return the Spark user for authenticating SASL connections. + */ + @Override + public String getSaslUser(String appId) { + return SPARK_SASL_USER; + } + + /** + * Return the secret key registered with the given application. + * This key is used to authenticate the executors before they can fetch shuffle files + * written by this application from the external shuffle service. If the specified + * application is not registered, return null. + */ + @Override + public String getSecretKey(String appId) { + return shuffleSecretMap.get(appId); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java new file mode 100644 index 0000000000000..9abad1f30a259 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.RealmChoiceCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import java.io.IOException; + +import com.google.common.base.Throwables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.spark.network.sasl.SparkSaslServer.*; + +/** + * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. This client initializes the protocol via a + * firstToken, which is then followed by a set of challenges and responses. + */ +public class SparkSaslClient { + private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class); + + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslClient saslClient; + + public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, + SASL_PROPS, new ClientCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** Used to initiate SASL handshake with server. */ + public synchronized byte[] firstToken() { + if (saslClient != null && saslClient.hasInitialResponse()) { + try { + return saslClient.evaluateChallenge(new byte[0]); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } else { + return new byte[0]; + } + } + + /** Determines whether the authentication exchange has completed. */ + public synchronized boolean isComplete() { + return saslClient != null && saslClient.isComplete(); + } + + /** + * Respond to server's SASL token. + * @param token contains server's SASL token + * @return client's response SASL token + */ + public synchronized byte[] response(byte[] token) { + try { + return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslClient might be using. + */ + public synchronized void dispose() { + if (saslClient != null) { + try { + saslClient.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslClient = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * that works with share secrets. + */ + private class ClientCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL client callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL client callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL client callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof RealmChoiceCallback) { + // ignore (?) + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java new file mode 100644 index 0000000000000..e87b17ead1e1a --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import java.io.IOException; +import java.util.Map; + +import com.google.common.base.Charsets; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. (It is not a server in the sense of accepting + * connections on some socket.) + */ +public class SparkSaslServer { + private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class); + + /** + * This is passed as the server name when creating the sasl client/server. + * This could be changed to be configurable in the future. + */ + static final String DEFAULT_REALM = "default"; + + /** + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + static final String DIGEST = "DIGEST-MD5"; + + /** + * The quality of protection is just "auth". This means that we are doing + * authentication only, we are not supporting integrity or privacy protection of the + * communication channel after authentication. This could be changed to be configurable + * in the future. + */ + static final Map SASL_PROPS = ImmutableMap.builder() + .put(Sasl.QOP, "auth") + .put(Sasl.SERVER_AUTH, "true") + .build(); + + /** Identifier for a certain secret key within the secretKeyHolder. */ + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslServer saslServer; + + public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS, + new DigestCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Determines whether the authentication exchange has completed successfully. + */ + public synchronized boolean isComplete() { + return saslServer != null && saslServer.isComplete(); + } + + /** + * Used to respond to server SASL tokens. + * @param token Server's SASL token + * @return response to send back to the server. + */ + public synchronized byte[] response(byte[] token) { + try { + return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslServer might be using. + */ + public synchronized void dispose() { + if (saslServer != null) { + try { + saslServer.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslServer = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism. + */ + private class DigestCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL server callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL server callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL server callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof AuthorizeCallback) { + AuthorizeCallback ac = (AuthorizeCallback) callback; + String authId = ac.getAuthenticationID(); + String authzId = ac.getAuthorizationID(); + ac.setAuthorized(authId.equals(authzId)); + if (ac.isAuthorized()) { + ac.setAuthorizedID(authzId); + } + logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized()); + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } + + /* Encode a byte[] identifier as a Base64-encoded string. */ + public static String encodeIdentifier(String identifier) { + Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); + return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8); + } + + /** Encode a password as a base64-encoded char[] array. */ + public static char[] encodePassword(String password) { + Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); + return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8).toCharArray(); + } +} diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java similarity index 73% rename from core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala rename to network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java index 645793fde806d..138fd5389c20a 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java @@ -15,28 +15,22 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.shuffle; -import java.util.EventListener +import java.util.EventListener; -import org.apache.spark.network.buffer.ManagedBuffer - - -/** - * Listener callback interface for [[BlockTransferService.fetchBlocks]]. - */ -private[spark] -trait BlockFetchingListener extends EventListener { +import org.apache.spark.network.buffer.ManagedBuffer; +public interface BlockFetchingListener extends EventListener { /** * Called once per successfully fetched block. After this call returns, data will be released * automatically. If the data will be passed to another thread, the receiver should retain() * and release() the buffer on their own, or copy the data to a new buffer. */ - def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit + void onBlockFetchSuccess(String blockId, ManagedBuffer data); /** * Called at least once per block upon failures. */ - def onBlockFetchFailure(blockId: String, exception: Throwable): Unit + void onBlockFetchFailure(String blockId, Throwable exception); } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java new file mode 100644 index 0000000000000..46ca9708621b9 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.util.List; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Lists; +import org.apache.spark.network.util.TransportConf; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; + +/** + * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. + * + * Handles registering executors and opening shuffle blocks from them. Shuffle blocks are registered + * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark- + * level shuffle block. + */ +public class ExternalShuffleBlockHandler extends RpcHandler { + private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); + + private final ExternalShuffleBlockManager blockManager; + private final OneForOneStreamManager streamManager; + + public ExternalShuffleBlockHandler(TransportConf conf) { + this(new OneForOneStreamManager(), new ExternalShuffleBlockManager(conf)); + } + + /** Enables mocking out the StreamManager and BlockManager. */ + @VisibleForTesting + ExternalShuffleBlockHandler( + OneForOneStreamManager streamManager, + ExternalShuffleBlockManager blockManager) { + this.streamManager = streamManager; + this.blockManager = blockManager; + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); + + if (msgObj instanceof OpenBlocks) { + OpenBlocks msg = (OpenBlocks) msgObj; + List blocks = Lists.newArrayList(); + + for (String blockId : msg.blockIds) { + blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId)); + } + long streamId = streamManager.registerStream(blocks.iterator()); + logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); + callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray()); + + } else if (msgObj instanceof RegisterExecutor) { + RegisterExecutor msg = (RegisterExecutor) msgObj; + blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo); + callback.onSuccess(new byte[0]); + + } else { + throw new UnsupportedOperationException("Unexpected message: " + msgObj); + } + } + + @Override + public StreamManager getStreamManager() { + return streamManager; + } + + /** + * Removes an application (once it has been terminated), and optionally will clean up any + * local directories associated with the executors of that application in a separate thread. + */ + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + blockManager.applicationRemoved(appId, cleanupLocalDirs); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java new file mode 100644 index 0000000000000..dfe0ba0595090 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Objects; +import com.google.common.collect.Maps; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Manages converting shuffle BlockIds into physical segments of local files, from a process outside + * of Executors. Each Executor must register its own configuration about where it stores its files + * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated + * from Spark's FileShuffleBlockManager and IndexShuffleBlockManager. + * + * Executors with shuffle file consolidation are not currently supported, as the index is stored in + * the Executor's memory, unlike the IndexShuffleBlockManager. + */ +public class ExternalShuffleBlockManager { + private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockManager.class); + + // Map containing all registered executors' metadata. + private final ConcurrentMap executors; + + // Single-threaded Java executor used to perform expensive recursive directory deletion. + private final Executor directoryCleaner; + + private final TransportConf conf; + + public ExternalShuffleBlockManager(TransportConf conf) { + // TODO: Give this thread a name. + this(conf, Executors.newSingleThreadExecutor()); + } + + // Allows tests to have more control over when directories are cleaned up. + @VisibleForTesting + ExternalShuffleBlockManager(TransportConf conf, Executor directoryCleaner) { + this.conf = conf; + this.executors = Maps.newConcurrentMap(); + this.directoryCleaner = directoryCleaner; + } + + /** Registers a new Executor with all the configuration we need to find its shuffle files. */ + public void registerExecutor( + String appId, + String execId, + ExecutorShuffleInfo executorInfo) { + AppExecId fullId = new AppExecId(appId, execId); + logger.info("Registered executor {} with {}", fullId, executorInfo); + executors.put(fullId, executorInfo); + } + + /** + * Obtains a FileSegmentManagedBuffer from a shuffle block id. We expect the blockId has the + * format "shuffle_ShuffleId_MapId_ReduceId" (from ShuffleBlockId), and additionally make + * assumptions about how the hash and sort based shuffles store their data. + */ + public ManagedBuffer getBlockData(String appId, String execId, String blockId) { + String[] blockIdParts = blockId.split("_"); + if (blockIdParts.length < 4) { + throw new IllegalArgumentException("Unexpected block id format: " + blockId); + } else if (!blockIdParts[0].equals("shuffle")) { + throw new IllegalArgumentException("Expected shuffle block id, got: " + blockId); + } + int shuffleId = Integer.parseInt(blockIdParts[1]); + int mapId = Integer.parseInt(blockIdParts[2]); + int reduceId = Integer.parseInt(blockIdParts[3]); + + ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); + if (executor == null) { + throw new RuntimeException( + String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); + } + + if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { + return getHashBasedShuffleBlockData(executor, blockId); + } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager)) { + return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + } else { + throw new UnsupportedOperationException( + "Unsupported shuffle manager: " + executor.shuffleManager); + } + } + + /** + * Removes our metadata of all executors registered for the given application, and optionally + * also deletes the local directories associated with the executors of that application in a + * separate thread. + * + * It is not valid to call registerExecutor() for an executor with this appId after invoking + * this method. + */ + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + logger.info("Application {} removed, cleanupLocalDirs = {}", appId, cleanupLocalDirs); + Iterator> it = executors.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + AppExecId fullId = entry.getKey(); + final ExecutorShuffleInfo executor = entry.getValue(); + + // Only touch executors associated with the appId that was removed. + if (appId.equals(fullId.appId)) { + it.remove(); + + if (cleanupLocalDirs) { + logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length); + + // Execute the actual deletion in a different thread, as it may take some time. + directoryCleaner.execute(new Runnable() { + @Override + public void run() { + deleteExecutorDirs(executor.localDirs); + } + }); + } + } + } + } + + /** + * Synchronously deletes each directory one at a time. + * Should be executed in its own thread, as this may take a long time. + */ + private void deleteExecutorDirs(String[] dirs) { + for (String localDir : dirs) { + try { + JavaUtils.deleteRecursively(new File(localDir)); + logger.debug("Successfully cleaned up directory: " + localDir); + } catch (Exception e) { + logger.error("Failed to delete directory: " + localDir, e); + } + } + } + + /** + * Hash-based shuffle data is simply stored as one file per block. + * This logic is from FileShuffleBlockManager. + */ + // TODO: Support consolidated hash shuffle files + private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { + File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); + return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); + } + + /** + * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file + * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockManager, + * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId. + */ + private ManagedBuffer getSortBasedShuffleBlockData( + ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) { + File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.index"); + + DataInputStream in = null; + try { + in = new DataInputStream(new FileInputStream(indexFile)); + in.skipBytes(reduceId * 8); + long offset = in.readLong(); + long nextOffset = in.readLong(); + return new FileSegmentManagedBuffer( + conf, + getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.data"), + offset, + nextOffset - offset); + } catch (IOException e) { + throw new RuntimeException("Failed to open file: " + indexFile, e); + } finally { + if (in != null) { + JavaUtils.closeQuietly(in); + } + } + } + + /** + * Hashes a filename into the corresponding local directory, in a manner consistent with + * Spark's DiskBlockManager.getFile(). + */ + @VisibleForTesting + static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) { + int hash = JavaUtils.nonNegativeHash(filename); + String localDir = localDirs[hash % localDirs.length]; + int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; + return new File(new File(localDir, String.format("%02x", subDirId)), filename); + } + + /** Simply encodes an executor's full ID, which is appId + execId. */ + private static class AppExecId { + final String appId; + final String execId; + + private AppExecId(String appId, String execId) { + this.appId = appId; + this.execId = execId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + AppExecId appExecId = (AppExecId) o; + return Objects.equal(appId, appExecId.appId) && Objects.equal(execId, appExecId.execId); + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .toString(); + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java new file mode 100644 index 0000000000000..6e8018b723dc6 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.util.List; + +import com.google.common.collect.Lists; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.sasl.SaslClientBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.util.TransportConf; + +/** + * Client for reading shuffle blocks which points to an external (outside of executor) server. + * This is instead of reading shuffle blocks directly from other executors (via + * BlockTransferService), which has the downside of losing the shuffle data if we lose the + * executors. + */ +public class ExternalShuffleClient extends ShuffleClient { + private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); + + private final TransportConf conf; + private final boolean saslEnabled; + private final SecretKeyHolder secretKeyHolder; + + private TransportClientFactory clientFactory; + private String appId; + + /** + * Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled, + * then secretKeyHolder may be null. + */ + public ExternalShuffleClient( + TransportConf conf, + SecretKeyHolder secretKeyHolder, + boolean saslEnabled) { + this.conf = conf; + this.secretKeyHolder = secretKeyHolder; + this.saslEnabled = saslEnabled; + } + + @Override + public void init(String appId) { + this.appId = appId; + TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + List bootstraps = Lists.newArrayList(); + if (saslEnabled) { + bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder)); + } + clientFactory = context.createClientFactory(bootstraps); + } + + @Override + public void fetchBlocks( + final String host, + final int port, + final String execId, + String[] blockIds, + BlockFetchingListener listener) { + assert appId != null : "Called before init()"; + logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); + try { + RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = + new RetryingBlockFetcher.BlockFetchStarter() { + @Override + public void createAndStart(String[] blockIds, BlockFetchingListener listener) + throws IOException { + TransportClient client = clientFactory.createClient(host, port); + new OneForOneBlockFetcher(client, appId, execId, blockIds, listener).start(); + } + }; + + int maxRetries = conf.maxIORetries(); + if (maxRetries > 0) { + // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's + // a bug in this code. We should remove the if statement once we're sure of the stability. + new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start(); + } else { + blockFetchStarter.createAndStart(blockIds, listener); + } + } catch (Exception e) { + logger.error("Exception while beginning fetchBlocks", e); + for (String blockId : blockIds) { + listener.onBlockFetchFailure(blockId, e); + } + } + } + + /** + * Registers this executor with an external shuffle server. This registration is required to + * inform the shuffle server about where and how we store our shuffle files. + * + * @param host Host of shuffle server. + * @param port Port of shuffle server. + * @param execId This Executor's id. + * @param executorInfo Contains all info necessary for the service to find our shuffle files. + */ + public void registerWithShuffleServer( + String host, + int port, + String execId, + ExecutorShuffleInfo executorInfo) throws IOException { + assert appId != null : "Called before init()"; + TransportClient client = clientFactory.createClient(host, port); + byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + } + + @Override + public void close() { + clientFactory.close(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java new file mode 100644 index 0000000000000..8ed2e0b39ad23 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.util.Arrays; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.JavaUtils; + +/** + * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and + * invokes the BlockFetchingListener appropriately. This class is agnostic to the actual RPC + * handler, as long as there is a single "open blocks" message which returns a ShuffleStreamHandle, + * and Java serialization is used. + * + * Note that this typically corresponds to a + * {@link org.apache.spark.network.server.OneForOneStreamManager} on the server side. + */ +public class OneForOneBlockFetcher { + private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class); + + private final TransportClient client; + private final OpenBlocks openMessage; + private final String[] blockIds; + private final BlockFetchingListener listener; + private final ChunkReceivedCallback chunkCallback; + + private StreamHandle streamHandle = null; + + public OneForOneBlockFetcher( + TransportClient client, + String appId, + String execId, + String[] blockIds, + BlockFetchingListener listener) { + this.client = client; + this.openMessage = new OpenBlocks(appId, execId, blockIds); + this.blockIds = blockIds; + this.listener = listener; + this.chunkCallback = new ChunkCallback(); + } + + /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ + private class ChunkCallback implements ChunkReceivedCallback { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + // On receipt of a chunk, pass it upwards as a block. + listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + // On receipt of a failure, fail every block from chunkIndex onwards. + String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); + failRemainingBlocks(remainingBlockIds, e); + } + } + + /** + * Begins the fetching process, calling the listener with every block fetched. + * The given message will be serialized with the Java serializer, and the RPC must return a + * {@link StreamHandle}. We will send all fetch requests immediately, without throttling. + */ + public void start() { + if (blockIds.length == 0) { + throw new IllegalArgumentException("Zero-sized blockIds array"); + } + + client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + try { + streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); + logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); + + // Immediately request all chunks -- we expect that the total size of the request is + // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. + for (int i = 0; i < streamHandle.numChunks; i++) { + client.fetchChunk(streamHandle.streamId, i, chunkCallback); + } + } catch (Exception e) { + logger.error("Failed while starting block fetches after success", e); + failRemainingBlocks(blockIds, e); + } + } + + @Override + public void onFailure(Throwable e) { + logger.error("Failed while starting block fetches", e); + failRemainingBlocks(blockIds, e); + } + }); + } + + /** Invokes the "onBlockFetchFailure" callback for every listed block id. */ + private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { + for (String blockId : failedBlockIds) { + try { + listener.onBlockFetchFailure(blockId, e); + } catch (Exception e2) { + logger.error("Error in block fetch failure callback", e2); + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java new file mode 100644 index 0000000000000..f8a1a266863bb --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Sets; +import com.google.common.util.concurrent.Uninterruptibles; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Wraps another BlockFetcher with the ability to automatically retry fetches which fail due to + * IOExceptions, which we hope are due to transient network conditions. + * + * This fetcher provides stronger guarantees regarding the parent BlockFetchingListener. In + * particular, the listener will be invoked exactly once per blockId, with a success or failure. + */ +public class RetryingBlockFetcher { + + /** + * Used to initiate the first fetch for all blocks, and subsequently for retrying the fetch on any + * remaining blocks. + */ + public static interface BlockFetchStarter { + /** + * Creates a new BlockFetcher to fetch the given block ids which may do some synchronous + * bootstrapping followed by fully asynchronous block fetching. + * The BlockFetcher must eventually invoke the Listener on every input blockId, or else this + * method must throw an exception. + * + * This method should always attempt to get a new TransportClient from the + * {@link org.apache.spark.network.client.TransportClientFactory} in order to fix connection + * issues. + */ + void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException; + } + + /** Shared executor service used for waiting and retrying. */ + private static final ExecutorService executorService = Executors.newCachedThreadPool( + NettyUtils.createThreadFactory("Block Fetch Retry")); + + private final Logger logger = LoggerFactory.getLogger(RetryingBlockFetcher.class); + + /** Used to initiate new Block Fetches on our remaining blocks. */ + private final BlockFetchStarter fetchStarter; + + /** Parent listener which we delegate all successful or permanently failed block fetches to. */ + private final BlockFetchingListener listener; + + /** Max number of times we are allowed to retry. */ + private final int maxRetries; + + /** Milliseconds to wait before each retry. */ + private final int retryWaitTime; + + // NOTE: + // All of our non-final fields are synchronized under 'this' and should only be accessed/mutated + // while inside a synchronized block. + /** Number of times we've attempted to retry so far. */ + private int retryCount = 0; + + /** + * Set of all block ids which have not been fetched successfully or with a non-IO Exception. + * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet, + * input ordering is preserved, so we always request blocks in the same order the user provided. + */ + private final LinkedHashSet outstandingBlocksIds; + + /** + * The BlockFetchingListener that is active with our current BlockFetcher. + * When we start a retry, we immediately replace this with a new Listener, which causes all any + * old Listeners to ignore all further responses. + */ + private RetryingBlockFetchListener currentListener; + + public RetryingBlockFetcher( + TransportConf conf, + BlockFetchStarter fetchStarter, + String[] blockIds, + BlockFetchingListener listener) { + this.fetchStarter = fetchStarter; + this.listener = listener; + this.maxRetries = conf.maxIORetries(); + this.retryWaitTime = conf.ioRetryWaitTime(); + this.outstandingBlocksIds = Sets.newLinkedHashSet(); + Collections.addAll(outstandingBlocksIds, blockIds); + this.currentListener = new RetryingBlockFetchListener(); + } + + /** + * Initiates the fetch of all blocks provided in the constructor, with possible retries in the + * event of transient IOExceptions. + */ + public void start() { + fetchAllOutstanding(); + } + + /** + * Fires off a request to fetch all blocks that have not been fetched successfully or permanently + * failed (i.e., by a non-IOException). + */ + private void fetchAllOutstanding() { + // Start by retrieving our shared state within a synchronized block. + String[] blockIdsToFetch; + int numRetries; + RetryingBlockFetchListener myListener; + synchronized (this) { + blockIdsToFetch = outstandingBlocksIds.toArray(new String[outstandingBlocksIds.size()]); + numRetries = retryCount; + myListener = currentListener; + } + + // Now initiate the fetch on all outstanding blocks, possibly initiating a retry if that fails. + try { + fetchStarter.createAndStart(blockIdsToFetch, myListener); + } catch (Exception e) { + logger.error(String.format("Exception while beginning fetch of %s outstanding blocks %s", + blockIdsToFetch.length, numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e); + + if (shouldRetry(e)) { + initiateRetry(); + } else { + for (String bid : blockIdsToFetch) { + listener.onBlockFetchFailure(bid, e); + } + } + } + } + + /** + * Lightweight method which initiates a retry in a different thread. The retry will involve + * calling fetchAllOutstanding() after a configured wait time. + */ + private synchronized void initiateRetry() { + retryCount += 1; + currentListener = new RetryingBlockFetchListener(); + + logger.info("Retrying fetch ({}/{}) for {} outstanding blocks after {} ms", + retryCount, maxRetries, outstandingBlocksIds.size(), retryWaitTime); + + executorService.submit(new Runnable() { + @Override + public void run() { + Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS); + fetchAllOutstanding(); + } + }); + } + + /** + * Returns true if we should retry due a block fetch failure. We will retry if and only if + * the exception was an IOException and we haven't retried 'maxRetries' times already. + */ + private synchronized boolean shouldRetry(Throwable e) { + boolean isIOException = e instanceof IOException + || (e.getCause() != null && e.getCause() instanceof IOException); + boolean hasRemainingRetries = retryCount < maxRetries; + return isIOException && hasRemainingRetries; + } + + /** + * Our RetryListener intercepts block fetch responses and forwards them to our parent listener. + * Note that in the event of a retry, we will immediately replace the 'currentListener' field, + * indicating that any responses from non-current Listeners should be ignored. + */ + private class RetryingBlockFetchListener implements BlockFetchingListener { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + // We will only forward this success message to our parent listener if this block request is + // outstanding and we are still the active listener. + boolean shouldForwardSuccess = false; + synchronized (RetryingBlockFetcher.this) { + if (this == currentListener && outstandingBlocksIds.contains(blockId)) { + outstandingBlocksIds.remove(blockId); + shouldForwardSuccess = true; + } + } + + // Now actually invoke the parent listener, outside of the synchronized block. + if (shouldForwardSuccess) { + listener.onBlockFetchSuccess(blockId, data); + } + } + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + // We will only forward this failure to our parent listener if this block request is + // outstanding, we are still the active listener, AND we cannot retry the fetch. + boolean shouldForwardFailure = false; + synchronized (RetryingBlockFetcher.this) { + if (this == currentListener && outstandingBlocksIds.contains(blockId)) { + if (shouldRetry(exception)) { + initiateRetry(); + } else { + logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)", + blockId, retryCount), exception); + outstandingBlocksIds.remove(blockId); + shouldForwardFailure = true; + } + } + } + + // Now actually invoke the parent listener, outside of the synchronized block. + if (shouldForwardFailure) { + listener.onBlockFetchFailure(blockId, exception); + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java new file mode 100644 index 0000000000000..f72ab40690d0d --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.Closeable; + +/** Provides an interface for reading shuffle files, either from an Executor or external service. */ +public abstract class ShuffleClient implements Closeable { + + /** + * Initializes the ShuffleClient, specifying this Executor's appId. + * Must be called before any other method on the ShuffleClient. + */ + public void init(String appId) { } + + /** + * Fetch a sequence of blocks from a remote node asynchronously, + * + * Note that this API takes a sequence so the implementation can batch requests, and does not + * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as + * the data of a block is fetched, rather than waiting for all blocks to be fetched. + */ + public abstract void fetchBlocks( + String host, + int port, + String execId, + String[] blockIds, + BlockFetchingListener listener); +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java new file mode 100644 index 0000000000000..b4b13b8a6ef5d --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or + * by Spark's NettyBlockTransferService. + * + * At a high level: + * - OpenBlock is handled by both services, but only services shuffle files for the external + * shuffle service. It returns a StreamHandle. + * - UploadBlock is only handled by the NettyBlockTransferService. + * - RegisterExecutor is only handled by the external shuffle service. + */ +public abstract class BlockTransferMessage implements Encodable { + protected abstract Type type(); + + /** Preceding every serialized message is its type, which allows us to deserialize it. */ + public static enum Type { + OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 message types"; + this.id = (byte) id; + } + + public byte id() { return id; } + } + + // NB: Java does not support static methods in interfaces, so we must put this in a static class. + public static class Decoder { + /** Deserializes the 'type' byte followed by the message itself. */ + public static BlockTransferMessage fromByteArray(byte[] msg) { + ByteBuf buf = Unpooled.wrappedBuffer(msg); + byte type = buf.readByte(); + switch (type) { + case 0: return OpenBlocks.decode(buf); + case 1: return UploadBlock.decode(buf); + case 2: return RegisterExecutor.decode(buf); + case 3: return StreamHandle.decode(buf); + default: throw new IllegalArgumentException("Unknown message type: " + type); + } + } + } + + /** Serializes the 'type' byte followed by the message itself. */ + public byte[] toByteArray() { + ByteBuf buf = Unpooled.buffer(encodedLength()); + buf.writeByte(type().id); + encode(buf); + assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); + return buf.array(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java new file mode 100644 index 0000000000000..cadc8e8369c6a --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** Contains all configuration necessary for locating the shuffle files of an executor. */ +public class ExecutorShuffleInfo implements Encodable { + /** The base set of local directories that the executor stores its shuffle files in. */ + public final String[] localDirs; + /** Number of subdirectories created within each localDir. */ + public final int subDirsPerLocalDir; + /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */ + public final String shuffleManager; + + public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) { + this.localDirs = localDirs; + this.subDirsPerLocalDir = subDirsPerLocalDir; + this.shuffleManager = shuffleManager; + } + + @Override + public int hashCode() { + return Objects.hashCode(subDirsPerLocalDir, shuffleManager) * 41 + Arrays.hashCode(localDirs); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("localDirs", Arrays.toString(localDirs)) + .add("subDirsPerLocalDir", subDirsPerLocalDir) + .add("shuffleManager", shuffleManager) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof ExecutorShuffleInfo) { + ExecutorShuffleInfo o = (ExecutorShuffleInfo) other; + return Arrays.equals(localDirs, o.localDirs) + && Objects.equal(subDirsPerLocalDir, o.subDirsPerLocalDir) + && Objects.equal(shuffleManager, o.shuffleManager); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.StringArrays.encodedLength(localDirs) + + 4 // int + + Encoders.Strings.encodedLength(shuffleManager); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.StringArrays.encode(buf, localDirs); + buf.writeInt(subDirsPerLocalDir); + Encoders.Strings.encode(buf, shuffleManager); + } + + public static ExecutorShuffleInfo decode(ByteBuf buf) { + String[] localDirs = Encoders.StringArrays.decode(buf); + int subDirsPerLocalDir = buf.readInt(); + String shuffleManager = Encoders.Strings.decode(buf); + return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java new file mode 100644 index 0000000000000..60485bace643c --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** Request to read a set of blocks. Returns {@link StreamHandle}. */ +public class OpenBlocks extends BlockTransferMessage { + public final String appId; + public final String execId; + public final String[] blockIds; + + public OpenBlocks(String appId, String execId, String[] blockIds) { + this.appId = appId; + this.execId = execId; + this.blockIds = blockIds; + } + + @Override + protected Type type() { return Type.OPEN_BLOCKS; } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockIds", Arrays.toString(blockIds)) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof OpenBlocks) { + OpenBlocks o = (OpenBlocks) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Arrays.equals(blockIds, o.blockIds); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.StringArrays.encodedLength(blockIds); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.StringArrays.encode(buf, blockIds); + } + + public static OpenBlocks decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String[] blockIds = Encoders.StringArrays.decode(buf); + return new OpenBlocks(appId, execId, blockIds); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java new file mode 100644 index 0000000000000..38acae3b31d64 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** + * Initial registration message between an executor and its local shuffle server. + * Returns nothing (empty bye array). + */ +public class RegisterExecutor extends BlockTransferMessage { + public final String appId; + public final String execId; + public final ExecutorShuffleInfo executorInfo; + + public RegisterExecutor( + String appId, + String execId, + ExecutorShuffleInfo executorInfo) { + this.appId = appId; + this.execId = execId; + this.executorInfo = executorInfo; + } + + @Override + protected Type type() { return Type.REGISTER_EXECUTOR; } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId, executorInfo); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("executorInfo", executorInfo) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof RegisterExecutor) { + RegisterExecutor o = (RegisterExecutor) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(executorInfo, o.executorInfo); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + executorInfo.encodedLength(); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + executorInfo.encode(buf); + } + + public static RegisterExecutor decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + ExecutorShuffleInfo executorShuffleInfo = ExecutorShuffleInfo.decode(buf); + return new RegisterExecutor(appId, execId, executorShuffleInfo); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java new file mode 100644 index 0000000000000..21369c8cfb0d6 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.io.Serializable; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" + * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}. + */ +public class StreamHandle extends BlockTransferMessage { + public final long streamId; + public final int numChunks; + + public StreamHandle(long streamId, int numChunks) { + this.streamId = streamId; + this.numChunks = numChunks; + } + + @Override + protected Type type() { return Type.STREAM_HANDLE; } + + @Override + public int hashCode() { + return Objects.hashCode(streamId, numChunks); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("numChunks", numChunks) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof StreamHandle) { + StreamHandle o = (StreamHandle) other; + return Objects.equal(streamId, o.streamId) + && Objects.equal(numChunks, o.numChunks); + } + return false; + } + + @Override + public int encodedLength() { + return 8 + 4; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(streamId); + buf.writeInt(numChunks); + } + + public static StreamHandle decode(ByteBuf buf) { + long streamId = buf.readLong(); + int numChunks = buf.readInt(); + return new StreamHandle(streamId, numChunks); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java new file mode 100644 index 0000000000000..38abe29cc585f --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ +public class UploadBlock extends BlockTransferMessage { + public final String appId; + public final String execId; + public final String blockId; + // TODO: StorageLevel is serialized separately in here because StorageLevel is not available in + // this package. We should avoid this hack. + public final byte[] metadata; + public final byte[] blockData; + + /** + * @param metadata Meta-information about block, typically StorageLevel. + * @param blockData The actual block's bytes. + */ + public UploadBlock( + String appId, + String execId, + String blockId, + byte[] metadata, + byte[] blockData) { + this.appId = appId; + this.execId = execId; + this.blockId = blockId; + this.metadata = metadata; + this.blockData = blockData; + } + + @Override + protected Type type() { return Type.UPLOAD_BLOCK; } + + @Override + public int hashCode() { + int objectsHashCode = Objects.hashCode(appId, execId, blockId); + return (objectsHashCode * 41 + Arrays.hashCode(metadata)) * 41 + Arrays.hashCode(blockData); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockId", blockId) + .add("metadata size", metadata.length) + .add("block size", blockData.length) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadBlock) { + UploadBlock o = (UploadBlock) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(blockId, o.blockId) + && Arrays.equals(metadata, o.metadata) + && Arrays.equals(blockData, o.blockData); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.Strings.encodedLength(blockId) + + Encoders.ByteArrays.encodedLength(metadata) + + Encoders.ByteArrays.encodedLength(blockData); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.Strings.encode(buf, blockId); + Encoders.ByteArrays.encode(buf, metadata); + Encoders.ByteArrays.encode(buf, blockData); + } + + public static UploadBlock decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String blockId = Encoders.Strings.decode(buf); + byte[] metadata = Encoders.ByteArrays.decode(buf); + byte[] blockData = Encoders.ByteArrays.decode(buf); + return new UploadBlock(appId, execId, blockId, metadata, blockData); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java new file mode 100644 index 0000000000000..d25283e46ef96 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.io.IOException; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class SaslIntegrationSuite { + static ExternalShuffleBlockHandler handler; + static TransportServer server; + static TransportConf conf; + static TransportContext context; + + TransportClientFactory clientFactory; + + /** Provides a secret key holder which always returns the given secret key. */ + static class TestSecretKeyHolder implements SecretKeyHolder { + + private final String secretKey; + + TestSecretKeyHolder(String secretKey) { + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + @Override + public String getSecretKey(String appId) { + return secretKey; + } + } + + + @BeforeClass + public static void beforeAll() throws IOException { + SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key"); + SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder); + conf = new TransportConf(new SystemPropertyConfigProvider()); + context = new TransportContext(conf, handler); + server = context.createServer(); + } + + + @AfterClass + public static void afterAll() { + server.close(); + } + + @After + public void afterEach() { + if (clientFactory != null) { + clientFactory.close(); + clientFactory = null; + } + } + + @Test + public void testGoodClient() throws IOException { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + String msg = "Hello, World!"; + byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); + assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + } + + @Test + public void testBadClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key")))); + + try { + // Bootstrap should fail on startup. + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + @Test + public void testNoSaslClient() throws IOException { + clientFactory = context.createClientFactory( + Lists.newArrayList()); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.sendRpcSync(new byte[13], 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); + } + + try { + // Guessing the right tag byte doesn't magically get you in... + client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); + } + } + + @Test + public void testNoSaslServer() { + RpcHandler handler = new TestRpcHandler(); + TransportContext context = new TransportContext(conf, handler); + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key")))); + TransportServer server = context.createServer(); + try { + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); + } finally { + server.close(); + } + } + + /** RPC handler which simply responds with the message it received. */ + public static class TestRpcHandler extends RpcHandler { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + callback.onSuccess(message); + } + + @Override + public StreamManager getStreamManager() { + return new OneForOneStreamManager(); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java new file mode 100644 index 0000000000000..67a07f38eb5a0 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. + */ +public class SparkSaslSuite { + + /** Provides a secret key holder which returns secret key == appId */ + private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() { + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + return appId; + } + }; + + @Test + public void testMatching() { + SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + assertTrue(server.isComplete()); + + // Disposal should invalidate + server.dispose(); + assertFalse(server.isComplete()); + client.dispose(); + assertFalse(client.isComplete()); + } + + + @Test + public void testNonMatching() { + SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + try { + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + fail("Should not have completed"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("Mismatched response")); + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java new file mode 100644 index 0000000000000..d65de9ca550a3 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.shuffle.protocol.*; + +/** Verifies that all BlockTransferMessages can be serialized correctly. */ +public class BlockTransferMessagesSuite { + @Test + public void serializeOpenShuffleBlocks() { + checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" })); + checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( + new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"))); + checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 }, + new byte[] { 4, 5, 6, 7} )); + checkSerializeDeserialize(new StreamHandle(12345, 16)); + } + + private void checkSerializeDeserialize(BlockTransferMessage msg) { + BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray()); + assertEquals(msg, msg2); + assertEquals(msg.hashCode(), msg2.hashCode()); + assertEquals(msg.toString(), msg2.toString()); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java new file mode 100644 index 0000000000000..3f9fe1681cf27 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.nio.ByteBuffer; +import java.util.Iterator; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import static org.junit.Assert.*; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.shuffle.protocol.UploadBlock; + +public class ExternalShuffleBlockHandlerSuite { + TransportClient client = mock(TransportClient.class); + + OneForOneStreamManager streamManager; + ExternalShuffleBlockManager blockManager; + RpcHandler handler; + + @Before + public void beforeEach() { + streamManager = mock(OneForOneStreamManager.class); + blockManager = mock(ExternalShuffleBlockManager.class); + handler = new ExternalShuffleBlockHandler(streamManager, blockManager); + } + + @Test + public void testRegisterExecutor() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); + byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray(); + handler.receive(client, registerMessage, callback); + verify(blockManager, times(1)).registerExecutor("app0", "exec1", config); + + verify(callback, times(1)).onSuccess((byte[]) any()); + verify(callback, never()).onFailure((Throwable) any()); + } + + @SuppressWarnings("unchecked") + @Test + public void testOpenShuffleBlocks() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); + ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + when(blockManager.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); + when(blockManager.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); + byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray(); + handler.receive(client, openBlocks, callback); + verify(blockManager, times(1)).getBlockData("app0", "exec1", "b0"); + verify(blockManager, times(1)).getBlockData("app0", "exec1", "b1"); + + ArgumentCaptor response = ArgumentCaptor.forClass(byte[].class); + verify(callback, times(1)).onSuccess(response.capture()); + verify(callback, never()).onFailure((Throwable) any()); + + StreamHandle handle = + (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue()); + assertEquals(2, handle.numChunks); + + ArgumentCaptor stream = ArgumentCaptor.forClass(Iterator.class); + verify(streamManager, times(1)).registerStream(stream.capture()); + Iterator buffers = (Iterator) stream.getValue(); + assertEquals(block0Marker, buffers.next()); + assertEquals(block1Marker, buffers.next()); + assertFalse(buffers.hasNext()); + } + + @Test + public void testBadMessages() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 }; + try { + handler.receive(client, unserializableMsg, callback); + fail("Should have thrown"); + } catch (Exception e) { + // pass + } + + byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray(); + try { + handler.receive(client, unexpectedMsg, callback); + fail("Should have thrown"); + } catch (UnsupportedOperationException e) { + // pass + } + + verify(callback, never()).onSuccess((byte[]) any()); + verify(callback, never()).onFailure((Throwable) any()); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java new file mode 100644 index 0000000000000..dad6428a836fc --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; + +import com.google.common.io.CharStreams; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class ExternalShuffleBlockManagerSuite { + static String sortBlock0 = "Hello!"; + static String sortBlock1 = "World!"; + + static String hashBlock0 = "Elementary"; + static String hashBlock1 = "Tabular"; + + static TestShuffleDataContext dataContext; + + static TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + + @BeforeClass + public static void beforeAll() throws IOException { + dataContext = new TestShuffleDataContext(2, 5); + + dataContext.create(); + // Write some sort and hash data. + dataContext.insertSortShuffleData(0, 0, + new byte[][] { sortBlock0.getBytes(), sortBlock1.getBytes() } ); + dataContext.insertHashShuffleData(1, 0, + new byte[][] { hashBlock0.getBytes(), hashBlock1.getBytes() } ); + } + + @AfterClass + public static void afterAll() { + dataContext.cleanup(); + } + + @Test + public void testBadRequests() { + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); + // Unregistered executor + try { + manager.getBlockData("app0", "exec1", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (RuntimeException e) { + assertTrue("Bad error message: " + e, e.getMessage().contains("not registered")); + } + + // Invalid shuffle manager + manager.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar")); + try { + manager.getBlockData("app0", "exec2", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (UnsupportedOperationException e) { + // pass + } + + // Nonexistent shuffle block + manager.registerExecutor("app0", "exec3", + dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + try { + manager.getBlockData("app0", "exec3", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (Exception e) { + // pass + } + } + + @Test + public void testSortShuffleBlocks() throws IOException { + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); + manager.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + + InputStream block0Stream = + manager.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); + String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + block0Stream.close(); + assertEquals(sortBlock0, block0); + + InputStream block1Stream = + manager.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(); + String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + block1Stream.close(); + assertEquals(sortBlock1, block1); + } + + @Test + public void testHashShuffleBlocks() throws IOException { + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); + manager.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); + + InputStream block0Stream = + manager.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); + String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + block0Stream.close(); + assertEquals(hashBlock0, block0); + + InputStream block1Stream = + manager.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream(); + String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + block1Stream.close(); + assertEquals(hashBlock1, block1); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java new file mode 100644 index 0000000000000..254e3a7a32b98 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Test; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleCleanupSuite { + + // Same-thread Executor used to ensure cleanup happens synchronously in test thread. + Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + + @Test + public void noCleanupAndCleanup() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", false /* cleanup */); + + assertStillThere(dataContext); + + manager.registerExecutor("app", "exec1", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true /* cleanup */); + + assertCleanedUp(dataContext); + } + + @Test + public void cleanupUsesExecutor() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + final AtomicBoolean cleanupCalled = new AtomicBoolean(false); + + // Executor which does nothing to ensure we're actually using it. + Executor noThreadExecutor = new Executor() { + @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } + }; + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, noThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true); + + assertTrue(cleanupCalled.get()); + assertStillThere(dataContext); + + dataContext.cleanup(); + assertCleanedUp(dataContext); + } + + @Test + public void cleanupMultipleExecutors() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + manager.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true); + + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + @Test + public void cleanupOnlyRemovedApp() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); + + manager.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + manager.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); + + manager.applicationRemoved("app-nonexistent", true); + assertStillThere(dataContext0); + assertStillThere(dataContext1); + + manager.applicationRemoved("app-0", true); + assertCleanedUp(dataContext0); + assertStillThere(dataContext1); + + manager.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + + // Make sure it's not an error to cleanup multiple times + manager.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + private void assertStillThere(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); + } + } + + private void assertCleanedUp(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertFalse(localDir + " wasn't cleaned up", new File(localDir).exists()); + } + } + + private TestShuffleDataContext createSomeData() throws IOException { + Random rand = new Random(123); + TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); + + dataContext.create(); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), + new byte[][] { "ABC".getBytes(), "DEF".getBytes() } ); + dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000, + new byte[][] { "GHI".getBytes(), "JKLMNOPQRSTUVWXYZ".getBytes() } ); + return dataContext; + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java new file mode 100644 index 0000000000000..02c10bcb7b261 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleIntegrationSuite { + + static String APP_ID = "app-id"; + static String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; + static String HASH_MANAGER = "org.apache.spark.shuffle.hash.HashShuffleManager"; + + // Executor 0 is sort-based + static TestShuffleDataContext dataContext0; + // Executor 1 is hash-based + static TestShuffleDataContext dataContext1; + + static ExternalShuffleBlockHandler handler; + static TransportServer server; + static TransportConf conf; + + static byte[][] exec0Blocks = new byte[][] { + new byte[123], + new byte[12345], + new byte[1234567], + }; + + static byte[][] exec1Blocks = new byte[][] { + new byte[321], + new byte[54321], + }; + + @BeforeClass + public static void beforeAll() throws IOException { + Random rand = new Random(); + + for (byte[] block : exec0Blocks) { + rand.nextBytes(block); + } + for (byte[] block: exec1Blocks) { + rand.nextBytes(block); + } + + dataContext0 = new TestShuffleDataContext(2, 5); + dataContext0.create(); + dataContext0.insertSortShuffleData(0, 0, exec0Blocks); + + dataContext1 = new TestShuffleDataContext(6, 2); + dataContext1.create(); + dataContext1.insertHashShuffleData(1, 0, exec1Blocks); + + conf = new TransportConf(new SystemPropertyConfigProvider()); + handler = new ExternalShuffleBlockHandler(conf); + TransportContext transportContext = new TransportContext(conf, handler); + server = transportContext.createServer(); + } + + @AfterClass + public static void afterAll() { + dataContext0.cleanup(); + dataContext1.cleanup(); + server.close(); + } + + @After + public void afterEach() { + handler.applicationRemoved(APP_ID, false /* cleanupLocalDirs */); + } + + class FetchResult { + public Set successBlocks; + public Set failedBlocks; + public List buffers; + + public void releaseBuffers() { + for (ManagedBuffer buffer : buffers) { + buffer.release(); + } + } + } + + // Fetch a set of blocks from a pre-registered executor. + private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception { + return fetchBlocks(execId, blockIds, server.getPort()); + } + + // Fetch a set of blocks from a pre-registered executor. Connects to the server on the given port, + // to allow connecting to invalid servers. + private FetchResult fetchBlocks(String execId, String[] blockIds, int port) throws Exception { + final FetchResult res = new FetchResult(); + res.successBlocks = Collections.synchronizedSet(new HashSet()); + res.failedBlocks = Collections.synchronizedSet(new HashSet()); + res.buffers = Collections.synchronizedList(new LinkedList()); + + final Semaphore requestsRemaining = new Semaphore(0); + + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + client.init(APP_ID); + client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, + new BlockFetchingListener() { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + data.retain(); + res.successBlocks.add(blockId); + res.buffers.add(data); + requestsRemaining.release(); + } + } + } + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + res.failedBlocks.add(blockId); + requestsRemaining.release(); + } + } + } + }); + + if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + + @Test + public void testFetchOneSort() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" }); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks[0])); + exec0Fetch.releaseBuffers(); + } + + @Test + public void testFetchThreeSort() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", + new String[] { "shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2" }); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2"), + exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks)); + exec0Fetch.releaseBuffers(); + } + + @Test + public void testFetchHash() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(HASH_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.successBlocks); + assertTrue(execFetch.failedBlocks.isEmpty()); + assertBufferListsEqual(execFetch.buffers, Lists.newArrayList(exec1Blocks)); + execFetch.releaseBuffers(); + } + + @Test + public void testFetchWrongShuffle() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + } + + @Test + public void testFetchInvalidShuffle() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo("unknown sort manager")); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchWrongBlockId() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "rdd_1_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("rdd_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchNonexistent() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[] { "shuffle_2_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_2_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchWrongExecutor() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ }); + // Both still fail, as we start by checking for all block. + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchUnregisteredExecutor() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-2", + new String[] { "shuffle_0_0_0", "shuffle_1_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchNoServer() throws Exception { + System.setProperty("spark.shuffle.io.maxRetries", "0"); + try { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + } finally { + System.clearProperty("spark.shuffle.io.maxRetries"); + } + } + + private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) + throws IOException { + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + client.init(APP_ID); + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), + executorId, executorInfo); + } + + private void assertBufferListsEqual(List list0, List list1) + throws Exception { + assertEquals(list0.size(), list1.size()); + for (int i = 0; i < list0.size(); i ++) { + assertBuffersEqual(list0.get(i), new NioManagedBuffer(ByteBuffer.wrap(list1.get(i)))); + } + } + + private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { + ByteBuffer nio0 = buffer0.nioByteBuffer(); + ByteBuffer nio1 = buffer1.nioByteBuffer(); + + int len = nio0.remaining(); + assertEquals(nio0.remaining(), nio1.remaining()); + for (int i = 0; i < len; i ++) { + assertEquals(nio0.get(), nio1.get()); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java new file mode 100644 index 0000000000000..759a12910c94d --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleSecuritySuite { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportServer server; + + @Before + public void beforeEach() { + RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(conf), + new TestSecretKeyHolder("my-app-id", "secret")); + TransportContext context = new TransportContext(conf, handler); + this.server = context.createServer(); + } + + @After + public void afterEach() { + if (server != null) { + server.close(); + server = null; + } + } + + @Test + public void testValid() throws IOException { + validate("my-app-id", "secret"); + } + + @Test + public void testBadAppId() { + try { + validate("wrong-app-id", "secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); + } + } + + @Test + public void testBadSecret() { + try { + validate("my-app-id", "bad-secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + /** Creates an ExternalShuffleClient and attempts to register with the server. */ + private void validate(String appId, String secretKey) throws IOException { + ExternalShuffleClient client = + new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true); + client.init(appId); + // Registration either succeeds or throws an exception. + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", + new ExecutorShuffleInfo(new String[0], 0, "")); + client.close(); + } + + /** Provides a secret key holder which always returns the given secret key, for a single appId. */ + static class TestSecretKeyHolder implements SecretKeyHolder { + private final String appId; + private final String secretKey; + + TestSecretKeyHolder(String appId, String secretKey) { + this.appId = appId; + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + if (!appId.equals(this.appId)) { + throw new IllegalArgumentException("Wrong appId!"); + } + return secretKey; + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java new file mode 100644 index 0000000000000..842741e3d354f --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import com.google.common.collect.Maps; +import io.netty.buffer.Unpooled; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.StreamHandle; + +public class OneForOneBlockFetcherSuite { + @Test + public void testFetchOne() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0")); + } + + @Test + public void testFetchThree() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); + blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + for (int i = 0; i < 3; i ++) { + verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i)); + } + } + + @Test + public void testFailure() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", null); + blocks.put("b2", null); + + BlockFetchingListener listener = fetchBlocks(blocks); + + // Each failure will cause a failure to be invoked in all remaining block fetches. + verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); + verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, times(2)).onBlockFetchFailure(eq("b2"), (Throwable) any()); + } + + @Test + public void testFailureAndSuccess() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", null); + blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + // We may call both success and failure for the same block. + verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); + verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2")); + verify(listener, times(1)).onBlockFetchFailure(eq("b2"), (Throwable) any()); + } + + @Test + public void testEmptyBlockFetch() { + try { + fetchBlocks(Maps.newLinkedHashMap()); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Zero-sized blockIds array", e.getMessage()); + } + } + + /** + * Begins a fetch on the given set of blocks by mocking out the server side of the RPC which + * simply returns the given (BlockId, Block) pairs. + * As "blocks" is a LinkedHashMap, the blocks are guaranteed to be returned in the same order + * that they were inserted in. + * + * If a block's buffer is "null", an exception will be thrown instead. + */ + private BlockFetchingListener fetchBlocks(final LinkedHashMap blocks) { + TransportClient client = mock(TransportClient.class); + BlockFetchingListener listener = mock(BlockFetchingListener.class); + final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + OneForOneBlockFetcher fetcher = + new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener); + + // Respond to the "OpenBlocks" message with an appropirate ShuffleStreamHandle with streamId 123 + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray( + (byte[]) invocationOnMock.getArguments()[0]); + RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; + callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray()); + assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); + return null; + } + }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any()); + + // Respond to each chunk request with a single buffer from our blocks array. + final AtomicInteger expectedChunkIndex = new AtomicInteger(0); + final Iterator blockIterator = blocks.values().iterator(); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + try { + long streamId = (Long) invocation.getArguments()[0]; + int myChunkIndex = (Integer) invocation.getArguments()[1]; + assertEquals(123, streamId); + assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex); + + ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2]; + ManagedBuffer result = blockIterator.next(); + if (result != null) { + callback.onSuccess(myChunkIndex, result); + } else { + callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex)); + } + } catch (Exception e) { + e.printStackTrace(); + fail("Unexpected failure"); + } + return null; + } + }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any()); + + fetcher.start(); + return listener; + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java new file mode 100644 index 0000000000000..0191fe529e1be --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.LinkedHashSet; +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.stubbing.Stubber; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; +import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter; + +/** + * Tests retry logic by throwing IOExceptions and ensuring that subsequent attempts are made to + * fetch the lost blocks. + */ +public class RetryingBlockFetcherSuite { + + ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13])); + ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); + + @Before + public void beforeEach() { + System.setProperty("spark.shuffle.io.maxRetries", "2"); + System.setProperty("spark.shuffle.io.retryWaitMs", "0"); + } + + @After + public void afterEach() { + System.clearProperty("spark.shuffle.io.maxRetries"); + System.clearProperty("spark.shuffle.io.retryWaitMs"); + } + + @Test + public void testNoFailures() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // Immediately return both blocks successfully. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchSuccess("b0", block0); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testUnrecoverableFailure() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0 throws a non-IOException error, so it will be failed without retry. + ImmutableMap.builder() + .put("b0", new RuntimeException("Ouch!")) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any()); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnFirst() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // IOException will cause a retry. Since b0 fails, we will retry both. + ImmutableMap.builder() + .put("b0", new IOException("Connection failed or something")) + .put("b1", block1) + .build(), + ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnSecond() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // IOException will cause a retry. Since b1 fails, we will not retry b0. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException("Connection failed or something")) + .build(), + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testTwoIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 returns successfully within 2 retries. + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testThreeIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 errors again, but this was the last retry + ImmutableMap.builder() + .put("b1", new IOException()) + .build(), + // This is not reached -- b1 has failed. + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verifyNoMoreInteractions(listener); + } + + @Test + public void testRetryAndUnrecoverable() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, subsequent messages will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new RuntimeException()) + .put("b2", block2) + .build(), + // Next, b0 is successful, b1 errors unrecoverably, and b2 triggers a retry. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new RuntimeException()) + .put("b2", new IOException()) + .build(), + // b2 succeeds in its last retry. + ImmutableMap.builder() + .put("b2", block2) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2); + verifyNoMoreInteractions(listener); + } + + /** + * Performs a set of interactions in response to block requests from a RetryingBlockFetcher. + * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction + * means "respond to the next block fetch request with these Successful buffers and these Failure + * exceptions". We verify that the expected block ids are exactly the ones requested. + * + * If multiple interactions are supplied, they will be used in order. This is useful for encoding + * retries -- the first interaction may include an IOException, which causes a retry of some + * subset of the original blocks in a second interaction. + */ + @SuppressWarnings("unchecked") + private void performInteractions(final Map[] interactions, BlockFetchingListener listener) + throws IOException { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); + + Stubber stub = null; + + // Contains all blockIds that are referenced across all interactions. + final LinkedHashSet blockIds = Sets.newLinkedHashSet(); + + for (final Map interaction : interactions) { + blockIds.addAll(interaction.keySet()); + + Answer answer = new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + try { + // Verify that the RetryingBlockFetcher requested the expected blocks. + String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0]; + String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]); + assertArrayEquals(desiredBlockIds, requestedBlockIds); + + // Now actually invoke the success/failure callbacks on each block. + BlockFetchingListener retryListener = + (BlockFetchingListener) invocationOnMock.getArguments()[1]; + for (Map.Entry block : interaction.entrySet()) { + String blockId = block.getKey(); + Object blockValue = block.getValue(); + + if (blockValue instanceof ManagedBuffer) { + retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue); + } else if (blockValue instanceof Exception) { + retryListener.onBlockFetchFailure(blockId, (Exception) blockValue); + } else { + fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue); + } + } + return null; + } catch (Throwable e) { + e.printStackTrace(); + throw e; + } + } + }; + + // This is either the first stub, or should be chained behind the prior ones. + if (stub == null) { + stub = doAnswer(answer); + } else { + stub.doAnswer(answer); + } + } + + assert stub != null; + stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject()); + String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]); + new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start(); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java new file mode 100644 index 0000000000000..76639114df5d9 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import com.google.common.io.Files; + +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; + +/** + * Manages some sort- and hash-based shuffle data, including the creation + * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}. + */ +public class TestShuffleDataContext { + public final String[] localDirs; + public final int subDirsPerLocalDir; + + public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) { + this.localDirs = new String[numLocalDirs]; + this.subDirsPerLocalDir = subDirsPerLocalDir; + } + + public void create() { + for (int i = 0; i < localDirs.length; i ++) { + localDirs[i] = Files.createTempDir().getAbsolutePath(); + + for (int p = 0; p < subDirsPerLocalDir; p ++) { + new File(localDirs[i], String.format("%02x", p)).mkdirs(); + } + } + } + + public void cleanup() { + for (String localDir : localDirs) { + deleteRecursively(new File(localDir)); + } + } + + /** Creates reducer blocks in a sort-based data format within our local dirs. */ + public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { + String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; + + OutputStream dataStream = new FileOutputStream( + ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); + DataOutputStream indexStream = new DataOutputStream(new FileOutputStream( + ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + + long offset = 0; + indexStream.writeLong(offset); + for (byte[] block : blocks) { + offset += block.length; + dataStream.write(block); + indexStream.writeLong(offset); + } + + dataStream.close(); + indexStream.close(); + } + + /** Creates reducer blocks in a hash-based data format within our local dirs. */ + public void insertHashShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { + for (int i = 0; i < blocks.length; i ++) { + String blockId = "shuffle_" + shuffleId + "_" + mapId + "_" + i; + Files.write(blocks[i], + ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId)); + } + } + + /** + * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this + * context's directories. + */ + public ExecutorShuffleInfo createExecutorInfo(String shuffleManager) { + return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); + } + + private static void deleteRecursively(File f) { + assert f != null; + if (f.isDirectory()) { + File[] children = f.listFiles(); + if (children != null) { + for (File child : children) { + deleteRecursively(child); + } + } + } + f.delete(); + } +} diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml new file mode 100644 index 0000000000000..6e6f6f3e79296 --- /dev/null +++ b/network/yarn/pom.xml @@ -0,0 +1,58 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.2.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-yarn_2.10 + jar + Spark Project YARN Shuffle Service + http://spark.apache.org/ + + network-yarn + + + + + + org.apache.spark + spark-network-shuffle_2.10 + ${project.version} + + + + + org.apache.hadoop + hadoop-client + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java new file mode 100644 index 0000000000000..a34aabe9e78a6 --- /dev/null +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn; + +import java.lang.Override; +import java.nio.ByteBuffer; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.server.api.AuxiliaryService; +import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; +import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext; +import org.apache.hadoop.yarn.server.api.ContainerInitializationContext; +import org.apache.hadoop.yarn.server.api.ContainerTerminationContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.ShuffleSecretManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.yarn.util.HadoopConfigProvider; + +/** + * An external shuffle service used by Spark on Yarn. + * + * This is intended to be a long-running auxiliary service that runs in the NodeManager process. + * A Spark application may connect to this service by setting `spark.shuffle.service.enabled`. + * The application also automatically derives the service port through `spark.shuffle.service.port` + * specified in the Yarn configuration. This is so that both the clients and the server agree on + * the same port to communicate on. + * + * The service also optionally supports authentication. This ensures that executors from one + * application cannot read the shuffle files written by those from another. This feature can be + * enabled by setting `spark.authenticate` in the Yarn configuration before starting the NM. + * Note that the Spark application must also set `spark.authenticate` manually and, unlike in + * the case of the service port, will not inherit this setting from the Yarn configuration. This + * is because an application running on the same Yarn cluster may choose to not use the external + * shuffle service, in which case its setting of `spark.authenticate` should be independent of + * the service's. + */ +public class YarnShuffleService extends AuxiliaryService { + private final Logger logger = LoggerFactory.getLogger(YarnShuffleService.class); + + // Port on which the shuffle server listens for fetch requests + private static final String SPARK_SHUFFLE_SERVICE_PORT_KEY = "spark.shuffle.service.port"; + private static final int DEFAULT_SPARK_SHUFFLE_SERVICE_PORT = 7337; + + // Whether the shuffle server should authenticate fetch requests + private static final String SPARK_AUTHENTICATE_KEY = "spark.authenticate"; + private static final boolean DEFAULT_SPARK_AUTHENTICATE = false; + + // An entity that manages the shuffle secret per application + // This is used only if authentication is enabled + private ShuffleSecretManager secretManager; + + // The actual server that serves shuffle files + private TransportServer shuffleServer = null; + + public YarnShuffleService() { + super("spark_shuffle"); + logger.info("Initializing YARN shuffle service for Spark"); + } + + /** + * Return whether authentication is enabled as specified by the configuration. + * If so, fetch requests will fail unless the appropriate authentication secret + * for the application is provided. + */ + private boolean isAuthenticationEnabled() { + return secretManager != null; + } + + /** + * Start the shuffle server with the given configuration. + */ + @Override + protected void serviceInit(Configuration conf) { + TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); + // If authentication is enabled, set up the shuffle server to use a + // special RPC handler that filters out unauthenticated fetch requests + boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); + RpcHandler rpcHandler = new ExternalShuffleBlockHandler(transportConf); + if (authEnabled) { + secretManager = new ShuffleSecretManager(); + rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); + } + + int port = conf.getInt( + SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); + TransportContext transportContext = new TransportContext(transportConf, rpcHandler); + shuffleServer = transportContext.createServer(port); + String authEnabledString = authEnabled ? "enabled" : "not enabled"; + logger.info("Started YARN shuffle service for Spark on port {}. " + + "Authentication is {}.", port, authEnabledString); + } + + @Override + public void initializeApplication(ApplicationInitializationContext context) { + String appId = context.getApplicationId().toString(); + try { + ByteBuffer shuffleSecret = context.getApplicationDataForService(); + logger.info("Initializing application {}", appId); + if (isAuthenticationEnabled()) { + secretManager.registerApp(appId, shuffleSecret); + } + } catch (Exception e) { + logger.error("Exception when initializing application {}", appId, e); + } + } + + @Override + public void stopApplication(ApplicationTerminationContext context) { + String appId = context.getApplicationId().toString(); + try { + logger.info("Stopping application {}", appId); + if (isAuthenticationEnabled()) { + secretManager.unregisterApp(appId); + } + } catch (Exception e) { + logger.error("Exception when stopping application {}", appId, e); + } + } + + @Override + public void initializeContainer(ContainerInitializationContext context) { + ContainerId containerId = context.getContainerId(); + logger.info("Initializing container {}", containerId); + } + + @Override + public void stopContainer(ContainerTerminationContext context) { + ContainerId containerId = context.getContainerId(); + logger.info("Stopping container {}", containerId); + } + + /** + * Close the shuffle server to clean up any associated state. + */ + @Override + protected void serviceStop() { + try { + if (shuffleServer != null) { + shuffleServer.close(); + } + } catch (Exception e) { + logger.error("Exception when stopping service", e); + } + } + + // Not currently used + @Override + public ByteBuffer getMetaData() { + return ByteBuffer.allocate(0); + } + +} diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java new file mode 100644 index 0000000000000..884861752e80d --- /dev/null +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn.util; + +import java.util.NoSuchElementException; + +import org.apache.hadoop.conf.Configuration; + +import org.apache.spark.network.util.ConfigProvider; + +/** Use the Hadoop configuration to obtain config values. */ +public class HadoopConfigProvider extends ConfigProvider { + private final Configuration conf; + + public HadoopConfigProvider(Configuration conf) { + this.conf = conf; + } + + @Override + public String get(String name) { + String value = conf.get(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/pom.xml b/pom.xml index e4c92470fc03e..4e0cd6c151d0b 100644 --- a/pom.xml +++ b/pom.xml @@ -92,6 +92,7 @@ mllib tools network/common + network/shuffle streaming sql/catalyst sql/core @@ -129,11 +130,11 @@ 1.4.0 3.4.5 - 0.13.1 + 0.13.1a 0.13.1 10.10.1.1 - 1.4.3 + 1.6.0rc3 1.2.3 8.1.14.v20131031 0.3.6 @@ -144,6 +145,7 @@ 1.8.3 1.1.0 4.2.6 + 3.1.1 64m 512m @@ -240,6 +242,18 @@ false + + + spark-staging-hive13 + Spring Staging Repository Hive 13 + https://oss.sonatype.org/content/repositories/orgspark-project-1089/ + + true + + + false + + @@ -305,7 +319,7 @@ org.apache.commons commons-math3 - 3.3 + ${commons.math3.version} com.google.code.findbugs @@ -446,7 +460,7 @@ org.roaringbitmap RoaringBitmap - 0.4.3 + 0.4.5 commons-net @@ -908,9 +922,9 @@ by Spark SQL for code generation. --> - org.scalamacros - paradise_${scala.version} - ${scala.macros.version} + org.scalamacros + paradise_${scala.version} + ${scala.macros.version} @@ -1162,6 +1176,10 @@ + + hadoop-0.23 @@ -1191,6 +1209,7 @@ 2.3.0 2.5.0 0.9.0 + 3.1.1 hadoop2 @@ -1201,6 +1220,7 @@ 2.4.0 2.5.0 0.9.0 + 3.1.1 hadoop2 @@ -1216,6 +1236,7 @@ yarn yarn + network/yarn @@ -1314,14 +1335,19 @@ - hive-0.12.0 + hive false - sql/hive-thriftserver + + + hive-0.12.0 + + false + 0.12.0-protobuf-2.5 0.12.0 @@ -1334,7 +1360,7 @@ false - 0.13.1 + 0.13.1a 0.13.1 10.10.1.1 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 6a0495f8fd540..a94d09be3bec6 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -77,6 +77,14 @@ object MimaExcludes { // SPARK-3822 ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") + ) ++ Seq( + // SPARK-1209 + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"), + ProblemFilters.exclude[MissingTypesProblem]( + "org.apache.spark.rdd.PairRDDFunctions") ) case v if v.startsWith("1.1") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 6d5eb681c6131..657e4b4432775 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -31,15 +31,16 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, - sql, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, - streamingTwitter, streamingZeromq) = + sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, + streamingMqtt, streamingTwitter, streamingZeromq) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", - "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", - "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) + "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", + "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", + "streaming-zeromq").map(ProjectRef(buildLocation, _)) - val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = - Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl") - .map(ProjectRef(buildLocation, _)) + val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, networkYarn, java8Tests, + sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "yarn-alpha", "network-yarn", + "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") .map(ProjectRef(buildLocation, _)) @@ -142,7 +143,9 @@ object SparkBuild extends PomBuild { // TODO: Add Sql to mima checks allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - streamingFlumeSink).contains(x)).foreach(x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)) + streamingFlumeSink, networkCommon, networkShuffle, networkYarn).contains(x)).foreach { + x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) + } /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) diff --git a/python/docs/make.bat b/python/docs/make.bat index c011e82b4a35a..cc29acdc19686 100644 --- a/python/docs/make.bat +++ b/python/docs/make.bat @@ -1,6 +1,6 @@ -@ECHO OFF - -rem This is the entry point for running Sphinx documentation. To avoid polluting the -rem environment, it just launches a new cmd to do the real work. - -cmd /V /E /C %~dp0make2.bat %* +@ECHO OFF + +rem This is the entry point for running Sphinx documentation. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. + +cmd /V /E /C %~dp0make2.bat %* diff --git a/python/docs/make2.bat b/python/docs/make2.bat index 7bcaeafad13d7..05d22eb5cdd23 100644 --- a/python/docs/make2.bat +++ b/python/docs/make2.bat @@ -1,243 +1,243 @@ -@ECHO OFF - -REM Command file for Sphinx documentation - - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set BUILDDIR=_build -set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . -set I18NSPHINXOPTS=%SPHINXOPTS% . -if NOT "%PAPER%" == "" ( - set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% - set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% -) - -if "%1" == "" goto help - -if "%1" == "help" ( - :help - echo.Please use `make ^` where ^ is one of - echo. html to make standalone HTML files - echo. dirhtml to make HTML files named index.html in directories - echo. singlehtml to make a single large HTML file - echo. pickle to make pickle files - echo. json to make JSON files - echo. htmlhelp to make HTML files and a HTML help project - echo. qthelp to make HTML files and a qthelp project - echo. devhelp to make HTML files and a Devhelp project - echo. epub to make an epub - echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter - echo. text to make text files - echo. man to make manual pages - echo. texinfo to make Texinfo files - echo. gettext to make PO message catalogs - echo. changes to make an overview over all changed/added/deprecated items - echo. xml to make Docutils-native XML files - echo. pseudoxml to make pseudoxml-XML files for display purposes - echo. linkcheck to check all external links for integrity - echo. doctest to run all doctests embedded in the documentation if enabled - goto end -) - -if "%1" == "clean" ( - for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i - del /q /s %BUILDDIR%\* - goto end -) - - -%SPHINXBUILD% 2> nul -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -if "%1" == "html" ( - %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/html. - goto end -) - -if "%1" == "dirhtml" ( - %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. - goto end -) - -if "%1" == "singlehtml" ( - %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. - goto end -) - -if "%1" == "pickle" ( - %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the pickle files. - goto end -) - -if "%1" == "json" ( - %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the JSON files. - goto end -) - -if "%1" == "htmlhelp" ( - %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run HTML Help Workshop with the ^ -.hhp project file in %BUILDDIR%/htmlhelp. - goto end -) - -if "%1" == "qthelp" ( - %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run "qcollectiongenerator" with the ^ -.qhcp project file in %BUILDDIR%/qthelp, like this: - echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp - echo.To view the help file: - echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc - goto end -) - -if "%1" == "devhelp" ( - %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. - goto end -) - -if "%1" == "epub" ( - %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The epub file is in %BUILDDIR%/epub. - goto end -) - -if "%1" == "latex" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdf" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdfja" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf-ja - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "text" ( - %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The text files are in %BUILDDIR%/text. - goto end -) - -if "%1" == "man" ( - %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The manual pages are in %BUILDDIR%/man. - goto end -) - -if "%1" == "texinfo" ( - %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. - goto end -) - -if "%1" == "gettext" ( - %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The message catalogs are in %BUILDDIR%/locale. - goto end -) - -if "%1" == "changes" ( - %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes - if errorlevel 1 exit /b 1 - echo. - echo.The overview file is in %BUILDDIR%/changes. - goto end -) - -if "%1" == "linkcheck" ( - %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck - if errorlevel 1 exit /b 1 - echo. - echo.Link check complete; look for any errors in the above output ^ -or in %BUILDDIR%/linkcheck/output.txt. - goto end -) - -if "%1" == "doctest" ( - %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest - if errorlevel 1 exit /b 1 - echo. - echo.Testing of doctests in the sources finished, look at the ^ -results in %BUILDDIR%/doctest/output.txt. - goto end -) - -if "%1" == "xml" ( - %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The XML files are in %BUILDDIR%/xml. - goto end -) - -if "%1" == "pseudoxml" ( - %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. - goto end -) - -:end +@ECHO OFF + +REM Command file for Sphinx documentation + + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^` where ^ is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. xml to make Docutils-native XML files + echo. pseudoxml to make pseudoxml-XML files for display purposes + echo. linkcheck to check all external links for integrity + echo. doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + + +%SPHINXBUILD% 2> nul +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the pickle files. + goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. + goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdf" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdfja" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf-ja + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. + goto end +) + +if "%1" == "xml" ( + %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The XML files are in %BUILDDIR%/xml. + goto end +) + +if "%1" == "pseudoxml" ( + %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. + goto end +) + +:end diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 5f8dcedb1eea2..faa5952258aef 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, CompressedSerializer, AutoBatchedSerializer + PairDeserializer, CompressedSerializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -63,7 +63,6 @@ class SparkContext(object): _active_spark_context = None _lock = Lock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH - _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, @@ -115,9 +114,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size self._unbatched_serializer = serializer - if batchSize == 1: - self.serializer = self._unbatched_serializer - elif batchSize == 0: + if batchSize == 0: self.serializer = AutoBatchedSerializer(self._unbatched_serializer) else: self.serializer = BatchedSerializer(self._unbatched_serializer, @@ -305,12 +302,8 @@ def parallelize(self, c, numSlices=None): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = min(len(c) // numSlices, self._batchSize) - if batchSize > 1: - serializer = BatchedSerializer(self._unbatched_serializer, - batchSize) - else: - serializer = self._unbatched_serializer + batchSize = max(1, min(len(c) // numSlices, self._batchSize)) + serializer = BatchedSerializer(self._unbatched_serializer, batchSize) serializer.dump_stream(c, tempFile) tempFile.close() readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile @@ -328,8 +321,7 @@ def pickleFile(self, name, minPartitions=None): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ minPartitions = minPartitions or self.defaultMinPartitions - return RDD(self._jsc.objectFile(name, minPartitions), self, - BatchedSerializer(PickleSerializer())) + return RDD(self._jsc.objectFile(name, minPartitions), self) def textFile(self, name, minPartitions=None, use_unicode=True): """ @@ -396,6 +388,36 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): return RDD(self._jsc.wholeTextFiles(path, minPartitions), self, PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode))) + def binaryFiles(self, path, minPartitions=None): + """ + :: Experimental :: + + Read a directory of binary files from HDFS, a local file system + (available on all nodes), or any Hadoop-supported file system URI + as a byte array. Each file is read as a single record and returned + in a key-value pair, where the key is the path of each file, the + value is the content of each file. + + Note: Small files are preferred, large file is also allowable, but + may cause bad performance. + """ + minPartitions = minPartitions or self.defaultMinPartitions + return RDD(self._jsc.binaryFiles(path, minPartitions), self, + PairDeserializer(UTF8Deserializer(), NoOpSerializer())) + + def binaryRecords(self, path, recordLength): + """ + :: Experimental :: + + Load data from a flat binary file, assuming each record is a set of numbers + with the specified numerical format (see ByteBuffer), and the number of + bytes per record is constant. + + :param path: Directory to the input data files + :param recordLength: The length at which to split the records + """ + return RDD(self._jsc.binaryRecords(path, recordLength), self, NoOpSerializer()) + def _dictToJavaMap(self, d): jm = self._jvm.java.util.HashMap() if not d: @@ -405,7 +427,7 @@ def _dictToJavaMap(self, d): return jm def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, - valueConverter=None, minSplits=None, batchSize=None): + valueConverter=None, minSplits=None, batchSize=0): """ Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -427,17 +449,15 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, :param minSplits: minimum splits in dataset (default min(2, sc.defaultParallelism)) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ minSplits = minSplits or min(self.defaultParallelism, 2) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass, keyConverter, valueConverter, minSplits, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -458,18 +478,16 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -487,18 +505,16 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -519,18 +535,16 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -548,15 +562,13 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) @@ -836,7 +848,7 @@ def _test(): import doctest import tempfile globs = globals().copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') globs['tempdir'] = tempfile.mkdtemp() atexit.register(lambda: shutil.rmtree(globs['tempdir'])) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index e295c9d0954d9..5d90dddb5df1c 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -20,8 +20,8 @@ import numpy from numpy import array -from pyspark import SparkContext, PickleSerializer -from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc +from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper @@ -62,6 +62,7 @@ class LogisticRegressionModel(LinearModel): """ def predict(self, x): + x = _convert_to_vector(x) margin = self.weights.dot(x) + self._intercept if margin > 0: prob = 1 / (1 + exp(-margin)) @@ -79,7 +80,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a logistic regression model on the given data. - :param data: The training data. + :param data: The training data, an RDD of LabeledPoint. :param iterations: The number of iterations (default: 100). :param step: The step parameter used in SGD (default: 1.0). @@ -102,14 +103,11 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, training data (i.e. whether bias features are activated or not). """ - sc = data.context - - def train(jdata, i): - return sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD( - jdata, iterations, step, miniBatchFraction, i, regParam, regType, intercept) + def train(rdd, i): + return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, iterations, step, + miniBatchFraction, i, regParam, regType, intercept) - return _regression_train_wrapper(sc, train, LogisticRegressionModel, data, - initialWeights) + return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) class SVMModel(LinearModel): @@ -139,6 +137,7 @@ class SVMModel(LinearModel): """ def predict(self, x): + x = _convert_to_vector(x) margin = self.weights.dot(x) + self.intercept return 1 if margin >= 0 else 0 @@ -151,7 +150,7 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, """ Train a support vector machine on the given data. - :param data: The training data. + :param data: The training data, an RDD of LabeledPoint. :param iterations: The number of iterations (default: 100). :param step: The step parameter used in SGD (default: 1.0). @@ -174,13 +173,11 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, training data (i.e. whether bias features are activated or not). """ - sc = data.context - - def train(jrdd, i): - return sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD( - jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept) + def train(rdd, i): + return callMLlibFunc("trainSVMModelWithSGD", rdd, iterations, step, regParam, + miniBatchFraction, i, regType, intercept) - return _regression_train_wrapper(sc, train, SVMModel, data, initialWeights) + return _regression_train_wrapper(train, SVMModel, data, initialWeights) class NaiveBayesModel(object): @@ -238,19 +235,19 @@ def train(cls, data, lambda_=1.0): classification. By making every vector a 0-1 vector, it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). - :param data: RDD of NumPy vectors, one per element, where the first - coordinate is the label and the rest is the feature vector - (e.g. a count vector). + :param data: RDD of LabeledPoint. :param lambda_: The smoothing parameter """ - sc = data.context - jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(_to_java_object_rdd(data), lambda_) - labels, pi, theta = PickleSerializer().loads(str(sc._jvm.SerDe.dumps(jlist))) + first = data.first() + if not isinstance(first, LabeledPoint): + raise ValueError("`data` should be an RDD of LabeledPoint") + labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) def _test(): import doctest + from pyspark import SparkContext globs = globals().copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 5ee7997104d21..fe4c4cc5094d8 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -16,8 +16,8 @@ # from pyspark import SparkContext -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _to_java_object_rdd +from pyspark.mllib.linalg import SparseVector, _convert_to_vector __all__ = ['KMeansModel', 'KMeans'] @@ -80,14 +80,11 @@ class KMeans(object): @classmethod def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"): """Train a k-means clustering model.""" - sc = rdd.context - ser = PickleSerializer() # cache serialized data to avoid objects over head in JVM - cached = rdd.map(_convert_to_vector)._reserialize(AutoBatchedSerializer(ser)).cache() - model = sc._jvm.PythonMLLibAPI().trainKMeansModel( - _to_java_object_rdd(cached), k, maxIterations, runs, initializationMode) - bytes = sc._jvm.SerDe.dumps(model.clusterCenters()) - centers = ser.loads(str(bytes)) + jcached = _to_java_object_rdd(rdd.map(_convert_to_vector), cache=True) + model = callMLlibFunc("trainKMeansModel", jcached, k, maxIterations, runs, + initializationMode) + centers = callJavaFunc(rdd.context, model.clusterCenters) return KMeansModel([c.toArray() for c in centers]) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py new file mode 100644 index 0000000000000..c6149fe391ec8 --- /dev/null +++ b/python/pyspark/mllib/common.py @@ -0,0 +1,140 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import py4j.protocol +from py4j.protocol import Py4JJavaError +from py4j.java_gateway import JavaObject +from py4j.java_collections import MapConverter, ListConverter, JavaArray, JavaList + +from pyspark import RDD, SparkContext +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + + +# Hack for support float('inf') in Py4j +_old_smart_decode = py4j.protocol.smart_decode + +_float_str_mapping = { + 'nan': 'NaN', + 'inf': 'Infinity', + '-inf': '-Infinity', +} + + +def _new_smart_decode(obj): + if isinstance(obj, float): + s = unicode(obj) + return _float_str_mapping.get(s, s) + return _old_smart_decode(obj) + +py4j.protocol.smart_decode = _new_smart_decode + + +_picklable_classes = [ + 'LinkedList', + 'SparseVector', + 'DenseVector', + 'DenseMatrix', + 'Rating', + 'LabeledPoint', +] + + +# this will call the MLlib version of pythonToJava() +def _to_java_object_rdd(rdd, cache=False): + """ Return an JavaRDD of Object by unpickling + + It will convert each Python object into Java object by Pyrolite, whenever the + RDD is serialized in batch or not. + """ + rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) + if cache: + rdd.cache() + return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True) + + +def _py2java(sc, obj): + """ Convert Python object into Java """ + if isinstance(obj, RDD): + obj = _to_java_object_rdd(obj) + elif isinstance(obj, SparkContext): + obj = obj._jsc + elif isinstance(obj, dict): + obj = MapConverter().convert(obj, sc._gateway._gateway_client) + elif isinstance(obj, (list, tuple)): + obj = ListConverter().convert(obj, sc._gateway._gateway_client) + elif isinstance(obj, JavaObject): + pass + elif isinstance(obj, (int, long, float, bool, basestring)): + pass + else: + bytes = bytearray(PickleSerializer().dumps(obj)) + obj = sc._jvm.SerDe.loads(bytes) + return obj + + +def _java2py(sc, r): + if isinstance(r, JavaObject): + clsName = r.getClass().getSimpleName() + # convert RDD into JavaRDD + if clsName != 'JavaRDD' and clsName.endswith("RDD"): + r = r.toJavaRDD() + clsName = 'JavaRDD' + + if clsName == 'JavaRDD': + jrdd = sc._jvm.SerDe.javaToPython(r) + return RDD(jrdd, sc) + + if clsName in _picklable_classes: + r = sc._jvm.SerDe.dumps(r) + elif isinstance(r, (JavaArray, JavaList)): + try: + r = sc._jvm.SerDe.dumps(r) + except Py4JJavaError: + pass # not pickable + + if isinstance(r, bytearray): + r = PickleSerializer().loads(str(r)) + return r + + +def callJavaFunc(sc, func, *args): + """ Call Java Function """ + args = [_py2java(sc, a) for a in args] + return _java2py(sc, func(*args)) + + +def callMLlibFunc(name, *args): + """ Call API in PythonMLLibAPI """ + sc = SparkContext._active_spark_context + api = getattr(sc._jvm.PythonMLLibAPI(), name) + return callJavaFunc(sc, api, *args) + + +class JavaModelWrapper(object): + """ + Wrapper for the model in JVM + """ + def __init__(self, java_model): + self._sc = SparkContext._active_spark_context + self._java_model = java_model + + def __del__(self): + self._sc._gateway.detach(self._java_model) + + def call(self, name, *a): + """Call method of java_model""" + return callJavaFunc(self._sc, getattr(self._java_model, name), *a) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 324343443ebdb..9ec28079aef43 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -21,89 +21,16 @@ import sys import warnings -import py4j.protocol from py4j.protocol import Py4JJavaError -from py4j.java_gateway import JavaObject from pyspark import RDD, SparkContext -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.mllib.linalg import Vectors, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.linalg import Vectors, _convert_to_vector __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel'] -# Hack for support float('inf') in Py4j -_old_smart_decode = py4j.protocol.smart_decode - -_float_str_mapping = { - u'nan': u'NaN', - u'inf': u'Infinity', - u'-inf': u'-Infinity', -} - - -def _new_smart_decode(obj): - if isinstance(obj, float): - s = unicode(obj) - return _float_str_mapping.get(s, s) - return _old_smart_decode(obj) - -py4j.protocol.smart_decode = _new_smart_decode - - -# TODO: move these helper functions into utils -_picklable_classes = [ - 'LinkedList', - 'SparseVector', - 'DenseVector', - 'DenseMatrix', - 'Rating', - 'LabeledPoint', -] - - -def _py2java(sc, a): - """ Convert Python object into Java """ - if isinstance(a, RDD): - a = _to_java_object_rdd(a) - elif not isinstance(a, (int, long, float, bool, basestring)): - bytes = bytearray(PickleSerializer().dumps(a)) - a = sc._jvm.SerDe.loads(bytes) - return a - - -def _java2py(sc, r): - if isinstance(r, JavaObject): - clsName = r.getClass().getSimpleName() - if clsName in ("RDD", "JavaRDD"): - if clsName == "RDD": - r = r.toJavaRDD() - jrdd = sc._jvm.SerDe.javaToPython(r) - return RDD(jrdd, sc, AutoBatchedSerializer(PickleSerializer())) - - elif clsName in _picklable_classes: - r = sc._jvm.SerDe.dumps(r) - - if isinstance(r, bytearray): - r = PickleSerializer().loads(str(r)) - return r - - -def _callJavaFunc(sc, func, *args): - """ Call Java Function - """ - args = [_py2java(sc, a) for a in args] - return _java2py(sc, func(*args)) - - -def _callAPI(sc, name, *args): - """ Call API in PythonMLLibAPI - """ - api = getattr(sc._jvm.PythonMLLibAPI(), name) - return _callJavaFunc(sc, api, *args) - - class VectorTransformer(object): """ :: DeveloperApi :: @@ -154,31 +81,33 @@ def transform(self, vector): """ Applies unit length normalization on a vector. - :param vector: vector to be normalized. + :param vector: vector or RDD of vector to be normalized. :return: normalized vector. If the norm of the input is zero, it will return the input vector. """ sc = SparkContext._active_spark_context assert sc is not None, "SparkContext should be initialized first" - return _callAPI(sc, "normalizeVector", self.p, vector) + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + else: + vector = _convert_to_vector(vector) + return callMLlibFunc("normalizeVector", self.p, vector) -class JavaModelWrapper(VectorTransformer): +class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): """ Wrapper for the model in JVM """ - def __init__(self, sc, java_model): - self._sc = sc - self._java_model = java_model - - def __del__(self): - self._sc._gateway.detach(self._java_model) - def transform(self, dataset): - return _callJavaFunc(self._sc, self._java_model.transform, dataset) + def transform(self, vector): + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + else: + vector = _convert_to_vector(vector) + return self.call("transform", vector) -class StandardScalerModel(JavaModelWrapper): +class StandardScalerModel(JavaVectorTransformer): """ :: Experimental :: @@ -188,11 +117,11 @@ def transform(self, vector): """ Applies standardization transformation on a vector. - :param vector: Vector to be standardized. + :param vector: Vector or RDD of Vector to be standardized. :return: Standardized vector. If the variance of a column is zero, it will return default `0.0` for the column with zero variance. """ - return JavaModelWrapper.transform(self, vector) + return JavaVectorTransformer.transform(self, vector) class StandardScaler(object): @@ -233,9 +162,9 @@ def fit(self, dataset): the transformation model. :return: a StandardScalarModel """ - sc = dataset.context - jmodel = _callAPI(sc, "fitStandardScaler", self.withMean, self.withStd, dataset) - return StandardScalerModel(sc, jmodel) + dataset = dataset.map(_convert_to_vector) + jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, dataset) + return StandardScalerModel(jmodel) class HashingTF(object): @@ -276,7 +205,7 @@ def transform(self, document): return Vectors.sparse(self.numFeatures, freq.items()) -class IDFModel(JavaModelWrapper): +class IDFModel(JavaVectorTransformer): """ Represents an IDF model that can transform term frequency vectors. """ @@ -291,7 +220,9 @@ def transform(self, dataset): :param dataset: an RDD of term frequency vectors :return: an RDD of TF-IDF vectors """ - return JavaModelWrapper.transform(self, dataset) + if not isinstance(dataset, RDD): + raise TypeError("dataset should be an RDD of term frequency vectors") + return JavaVectorTransformer.transform(self, dataset) class IDF(object): @@ -335,12 +266,13 @@ def fit(self, dataset): :param dataset: an RDD of term frequency vectors """ - sc = dataset.context - jmodel = _callAPI(sc, "fitIDF", self.minDocFreq, dataset) - return IDFModel(sc, jmodel) + if not isinstance(dataset, RDD): + raise TypeError("dataset should be an RDD of term frequency vectors") + jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset.map(_convert_to_vector)) + return IDFModel(jmodel) -class Word2VecModel(JavaModelWrapper): +class Word2VecModel(JavaVectorTransformer): """ class for Word2Vec model """ @@ -354,7 +286,7 @@ def transform(self, word): :return: vector representation of word(s) """ try: - return _callJavaFunc(self._sc, self._java_model.transform, word) + return self.call("transform", word) except Py4JJavaError: raise ValueError("%s not found" % word) @@ -368,7 +300,9 @@ def findSynonyms(self, word, num): Note: local use only """ - words, similarity = _callJavaFunc(self._sc, self._java_model.findSynonyms, word, num) + if not isinstance(word, basestring): + word = _convert_to_vector(word) + words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) @@ -455,14 +389,15 @@ def fit(self, data): """ Computes the vector representation of each word in vocabulary. - :param data: training data. RDD of subtype of Iterable[String] + :param data: training data. RDD of list of string :return: Word2VecModel instance """ - sc = data.context - jmodel = _callAPI(sc, "trainWord2Vec", data, int(self.vectorSize), - float(self.learningRate), int(self.numPartitions), - int(self.numIterations), long(self.seed)) - return Word2VecModel(sc, jmodel) + if not isinstance(data, RDD): + raise TypeError("data should be an RDD of list of string") + jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), + float(self.learningRate), int(self.numPartitions), + int(self.numIterations), long(self.seed)) + return Word2VecModel(jmodel) def _test(): diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 1b9bf596242df..e35202dca0acc 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,9 +29,11 @@ import numpy as np -from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ + IntegerType, ByteType, Row -__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] + +__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices'] if sys.version_info[:2] == (2, 7): @@ -52,17 +54,6 @@ def fast_pickle_array(ar): _have_scipy = False -# this will call the MLlib version of pythonToJava() -def _to_java_object_rdd(rdd): - """ Return an JavaRDD of Object by unpickling - - It will convert each Python object into Java object by Pyrolite, whenever the - RDD is serialized in batch or not. - """ - rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) - return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True) - - def _convert_to_vector(l): if isinstance(l, Vector): return l @@ -118,7 +109,54 @@ def _format_float(f, digits=4): return s +class VectorUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Vector. + """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("size", IntegerType(), True), + StructField("indices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True)]) + + @classmethod + def module(cls): + return "pyspark.mllib.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.mllib.linalg.VectorUDT" + + def serialize(self, obj): + if isinstance(obj, SparseVector): + indices = [int(i) for i in obj.indices] + values = [float(v) for v in obj.values] + return (0, obj.size, indices, values) + elif isinstance(obj, DenseVector): + values = [float(v) for v in obj] + return (1, None, None, values) + else: + raise ValueError("cannot serialize %r of type %r" % (obj, type(obj))) + + def deserialize(self, datum): + assert len(datum) == 4, \ + "VectorUDT.deserialize given row with length %d but requires 4" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseVector(datum[1], datum[2], datum[3]) + elif tpe == 1: + return DenseVector(datum[3]) + else: + raise ValueError("do not recognize type %r" % tpe) + + class Vector(object): + + __UDT__ = VectorUDT() + """ Abstract class for DenseVector and SparseVector """ @@ -540,6 +578,8 @@ class DenseMatrix(Matrix): def __init__(self, numRows, numCols, values): Matrix.__init__(self, numRows, numCols) assert len(values) == numRows * numCols + if not isinstance(values, array.array): + values = array.array('d', values) self.values = values def __reduce__(self): @@ -558,6 +598,15 @@ def toArray(self): return np.reshape(self.values, (self.numRows, self.numCols), order='F') +class Matrices(object): + @staticmethod + def dense(numRows, numCols, values): + """ + Create a DenseMatrix + """ + return DenseMatrix(numRows, numCols, values) + + def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 2202c51ab9c06..cb4304f92152b 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -21,22 +21,12 @@ from functools import wraps -from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.mllib.common import callMLlibFunc __all__ = ['RandomRDDs', ] -def serialize(f): - @wraps(f) - def func(sc, *a, **kw): - jrdd = f(sc, *a, **kw) - return RDD(sc._jvm.SerDe.javaToPython(jrdd), sc, - BatchedSerializer(PickleSerializer(), 1024)) - return func - - def toArray(f): @wraps(f) def func(sc, *a, **kw): @@ -52,7 +42,6 @@ class RandomRDDs(object): """ @staticmethod - @serialize def uniformRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the @@ -63,6 +52,12 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): C{RandomRDDs.uniformRDD(sc, n, p, seed)\ .map(lambda v: a + (b - a) * v)} + :param sc: SparkContext used to create the RDD. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`. + >>> x = RandomRDDs.uniformRDD(sc, 100).collect() >>> len(x) 100 @@ -74,10 +69,9 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): >>> parts == sc.defaultParallelism True """ - return sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed) + return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed) @staticmethod - @serialize def normalRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the standard normal @@ -88,6 +82,12 @@ def normalRDD(sc, size, numPartitions=None, seed=None): C{RandomRDDs.normal(sc, n, p, seed)\ .map(lambda v: mean + sigma * v)} + :param sc: SparkContext used to create the RDD. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0). + >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L) >>> stats = x.stats() >>> stats.count() @@ -97,15 +97,21 @@ def normalRDD(sc, size, numPartitions=None, seed=None): >>> abs(stats.stdev() - 1.0) < 0.1 True """ - return sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed) + return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) @staticmethod - @serialize def poissonRDD(sc, mean, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or lambda, for the Poisson distribution. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ Pois(mean). + >>> mean = 100.0 >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2L) >>> stats = x.stats() @@ -117,16 +123,22 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): >>> abs(stats.stdev() - sqrt(mean)) < 0.5 True """ - return sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed) + return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) @staticmethod @toArray - @serialize def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn from the uniform distribution U(0.0, 1.0). + :param sc: SparkContext used to create the RDD. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD. + :param seed: Seed for the RNG that generates the seed for the generator in each partition. + :return: RDD of Vector with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. + >>> import numpy as np >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) >>> mat.shape @@ -136,17 +148,22 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() 4 """ - return sc._jvm.PythonMLLibAPI() \ - .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) + return callMLlibFunc("uniformVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) @staticmethod @toArray - @serialize def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn from the standard normal distribution. + :param sc: SparkContext used to create the RDD. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. + >>> import numpy as np >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect()) >>> mat.shape @@ -156,17 +173,23 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): >>> abs(mat.std() - 1.0) < 0.1 True """ - return sc._jvm.PythonMLLibAPI() \ - .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) + return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) @staticmethod @toArray - @serialize def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn from the Poisson distribution with the input mean. + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or lambda, for the Poisson distribution. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean). + >>> import numpy as np >>> mean = 100.0 >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L) @@ -179,8 +202,8 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): >>> abs(mat.std() - sqrt(mean)) < 0.5 True """ - return sc._jvm.PythonMLLibAPI() \ - .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed) + return callMLlibFunc("poissonVectorRDD", sc._jsc, float(mean), numRows, numCols, + numPartitions, seed) def _test(): diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 22872dbbe3b55..41bbd9a779c70 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -16,9 +16,8 @@ # from pyspark import SparkContext -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.rdd import RDD -from pyspark.mllib.linalg import _to_java_object_rdd +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd __all__ = ['MatrixFactorizationModel', 'ALS'] @@ -33,10 +32,10 @@ def __reduce__(self): return Rating, (self.user, self.product, self.rating) def __repr__(self): - return "Rating(%d, %d, %d)" % (self.user, self.product, self.rating) + return "Rating(%d, %d, %s)" % (self.user, self.product, self.rating) -class MatrixFactorizationModel(object): +class MatrixFactorizationModel(JavaModelWrapper): """A matrix factorisation model trained by regularized alternating least-squares. @@ -45,74 +44,55 @@ class MatrixFactorizationModel(object): >>> r2 = (1, 2, 2.0) >>> r3 = (2, 1, 2.0) >>> ratings = sc.parallelize([r1, r2, r3]) - >>> model = ALS.trainImplicit(ratings, 1) - >>> model.predict(2,2) is not None - True + >>> model = ALS.trainImplicit(ratings, 1, seed=10) + >>> model.predict(2,2) + 0.4473... >>> testset = sc.parallelize([(1, 2), (1, 1)]) - >>> model = ALS.train(ratings, 1) - >>> model.predictAll(testset).count() == 2 - True + >>> model = ALS.train(ratings, 1, seed=10) + >>> model.predictAll(testset).collect() + [Rating(1, 1, 1.0471...), Rating(1, 2, 1.9679...)] - >>> model = ALS.train(ratings, 4) - >>> model.userFeatures().count() == 2 - True + >>> model = ALS.train(ratings, 4, seed=10) + >>> model.userFeatures().collect() + [(2, array('d', [...])), (1, array('d', [...]))] >>> first_user = model.userFeatures().take(1)[0] >>> latents = first_user[1] >>> len(latents) == 4 True - >>> model.productFeatures().count() == 2 - True + >>> model.productFeatures().collect() + [(2, array('d', [...])), (1, array('d', [...]))] >>> first_product = model.productFeatures().take(1)[0] >>> latents = first_product[1] >>> len(latents) == 4 True - """ - - def __init__(self, sc, java_model): - self._context = sc - self._java_model = java_model - def __del__(self): - self._context._gateway.detach(self._java_model) + >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) + >>> model.predict(2,2) + 3.735... + >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) + >>> model.predict(2,2) + 0.4473... + """ def predict(self, user, product): - return self._java_model.predict(user, product) + return self._java_model.predict(int(user), int(product)) def predictAll(self, user_product): assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)" first = user_product.first() - if isinstance(first, list): - user_product = user_product.map(tuple) - first = tuple(first) - assert type(first) is tuple and len(first) == 2, \ - "user_product should be RDD of (user, product)" - if any(isinstance(x, str) for x in first): - user_product = user_product.map(lambda (u, p): (int(x), int(p))) - first = tuple(map(int, first)) - assert all(type(x) is int for x in first), "user and product in user_product shoul be int" - sc = self._context - tuplerdd = sc._jvm.SerDe.asTupleRDD(_to_java_object_rdd(user_product).rdd()) - jresult = self._java_model.predict(tuplerdd).toJavaRDD() - return RDD(sc._jvm.SerDe.javaToPython(jresult), sc, - AutoBatchedSerializer(PickleSerializer())) + assert len(first) == 2, "user_product should be RDD of (user, product)" + user_product = user_product.map(lambda (u, p): (int(u), int(p))) + return self.call("predict", user_product) def userFeatures(self): - sc = self._context - juf = self._java_model.userFeatures() - juf = sc._jvm.SerDe.fromTuple2RDD(juf).toJavaRDD() - return RDD(sc._jvm.PythonRDD.javaToPython(juf), sc, - AutoBatchedSerializer(PickleSerializer())) + return self.call("getUserFeatures") def productFeatures(self): - sc = self._context - jpf = self._java_model.productFeatures() - jpf = sc._jvm.SerDe.fromTuple2RDD(jpf).toJavaRDD() - return RDD(sc._jvm.PythonRDD.javaToPython(jpf), sc, - AutoBatchedSerializer(PickleSerializer())) + return self.call("getProductFeatures") class ALS(object): @@ -126,32 +106,28 @@ def _prepare(cls, ratings): ratings = ratings.map(lambda x: Rating(*x)) else: raise ValueError("rating should be RDD of Rating or tuple/list") - # serialize them by AutoBatchedSerializer before cache to reduce the - # objects overhead in JVM - cached = ratings._reserialize(AutoBatchedSerializer(PickleSerializer())).cache() - return _to_java_object_rdd(cached) + return _to_java_object_rdd(ratings, True) @classmethod - def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): - sc = ratings.context - jrating = cls._prepare(ratings) - mod = sc._jvm.PythonMLLibAPI().trainALSModel(jrating, rank, iterations, lambda_, blocks) - return MatrixFactorizationModel(sc, mod) + def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False, + seed=None): + model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations, + lambda_, blocks, nonnegative, seed) + return MatrixFactorizationModel(model) @classmethod - def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01): - sc = ratings.context - jrating = cls._prepare(ratings) - mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel( - jrating, rank, iterations, lambda_, blocks, alpha) - return MatrixFactorizationModel(sc, mod) + def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01, + nonnegative=False, seed=None): + model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank, + iterations, lambda_, blocks, alpha, nonnegative, seed) + return MatrixFactorizationModel(model) def _test(): import doctest import pyspark.mllib.recommendation globs = pyspark.mllib.recommendation.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 93e17faf5cd51..66e25a48dfa71 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,9 +18,8 @@ import numpy as np from numpy import array -from pyspark import SparkContext -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc, _to_java_object_rdd +from pyspark.mllib.linalg import SparseVector, _convert_to_vector __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] @@ -37,7 +36,7 @@ class LabeledPoint(object): """ def __init__(self, label, features): - self.label = label + self.label = float(label) self.features = _convert_to_vector(features) def __reduce__(self): @@ -47,7 +46,7 @@ def __str__(self): return "(" + ",".join((str(self.label), str(self.features))) + ")" def __repr__(self): - return "LabeledPoint(" + ",".join((repr(self.label), repr(self.features))) + ")" + return "LabeledPoint(%s, %s)" % (self.label, self.features) class LinearModel(object): @@ -56,7 +55,7 @@ class LinearModel(object): def __init__(self, weights, intercept): self._coeff = _convert_to_vector(weights) - self._intercept = intercept + self._intercept = float(intercept) @property def weights(self): @@ -67,7 +66,7 @@ def intercept(self): return self._intercept def __repr__(self): - return "(weights=%s, intercept=%s)" % (self._coeff, self._intercept) + return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept) class LinearRegressionModelBase(LinearModel): @@ -86,6 +85,7 @@ def predict(self, x): Predict the value of the dependent variable given a vector x containing values for the independent variables. """ + x = _convert_to_vector(x) return self.weights.dot(x) + self.intercept @@ -124,17 +124,14 @@ class LinearRegressionModel(LinearRegressionModelBase): # train_func should take two parameters, namely data and initial_weights, and # return the result of a call to the appropriate JVM stub. # _regression_train_wrapper is responsible for setup and error checking. -def _regression_train_wrapper(sc, train_func, modelClass, data, initial_weights): +def _regression_train_wrapper(train_func, modelClass, data, initial_weights): + first = data.first() + if not isinstance(first, LabeledPoint): + raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) initial_weights = initial_weights or [0.0] * len(data.first().features) - ser = PickleSerializer() - initial_bytes = bytearray(ser.dumps(_convert_to_vector(initial_weights))) - # use AutoBatchedSerializer before cache to reduce the memory - # overhead in JVM - cached = data._reserialize(AutoBatchedSerializer(ser)).cache() - ans = train_func(_to_java_object_rdd(cached), initial_bytes) - assert len(ans) == 2, "JVM call result had unexpected length" - weights = ser.loads(str(ans[0])) - return modelClass(weights, ans[1]) + weights, intercept = train_func(_to_java_object_rdd(data, cache=True), + _convert_to_vector(initial_weights)) + return modelClass(weights, intercept) class LinearRegressionWithSGD(object): @@ -168,13 +165,12 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, training data (i.e. whether bias features are activated or not). """ - sc = data.context + def train(rdd, i): + return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, iterations, step, + miniBatchFraction, i, regParam, regType, intercept) - def train(jrdd, i): - return sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( - jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept) - - return _regression_train_wrapper(sc, train, LinearRegressionModel, data, initialWeights) + return _regression_train_wrapper(train, LinearRegressionModel, + data, initialWeights) class LassoModel(LinearRegressionModelBase): @@ -216,12 +212,10 @@ class LassoWithSGD(object): def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None): """Train a Lasso regression model on the given data.""" - sc = data.context - - def train(jrdd, i): - return sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD( - jrdd, iterations, step, regParam, miniBatchFraction, i) - return _regression_train_wrapper(sc, train, LassoModel, data, initialWeights) + def train(rdd, i): + return callMLlibFunc("trainLassoModelWithSGD", rdd, iterations, step, regParam, + miniBatchFraction, i) + return _regression_train_wrapper(train, LassoModel, data, initialWeights) class RidgeRegressionModel(LinearRegressionModelBase): @@ -263,18 +257,19 @@ class RidgeRegressionWithSGD(object): def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None): """Train a ridge regression model on the given data.""" - sc = data.context - - def train(jrdd, i): - return sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD( - jrdd, iterations, step, regParam, miniBatchFraction, i) + def train(rdd, i): + return callMLlibFunc("trainRidgeModelWithSGD", rdd, iterations, step, regParam, + miniBatchFraction, i) - return _regression_train_wrapper(sc, train, RidgeRegressionModel, data, initialWeights) + return _regression_train_wrapper(train, RidgeRegressionModel, + data, initialWeights) def _test(): import doctest - globs = globals().copy() + from pyspark import SparkContext + import pyspark.mllib.regression + globs = pyspark.mllib.regression.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 84baf12b906df..1980f5b03f430 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -19,66 +19,86 @@ Python package for statistical functions in MLlib. """ -from functools import wraps +from pyspark import RDD +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.linalg import Matrix, _convert_to_vector +from pyspark.mllib.regression import LabeledPoint -from pyspark import PickleSerializer -from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd +__all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics'] -__all__ = ['MultivariateStatisticalSummary', 'Statistics'] - -def serialize(f): - ser = PickleSerializer() - - @wraps(f) - def func(self): - jvec = f(self) - bytes = self._sc._jvm.SerDe.dumps(jvec) - return ser.loads(str(bytes)).toArray() - - return func - - -class MultivariateStatisticalSummary(object): +class MultivariateStatisticalSummary(JavaModelWrapper): """ Trait for multivariate statistical summary of a data matrix. """ - def __init__(self, sc, java_summary): - """ - :param sc: Spark context - :param java_summary: Handle to Java summary object - """ - self._sc = sc - self._java_summary = java_summary - - def __del__(self): - self._sc._gateway.detach(self._java_summary) - - @serialize def mean(self): - return self._java_summary.mean() + return self.call("mean").toArray() - @serialize def variance(self): - return self._java_summary.variance() + return self.call("variance").toArray() def count(self): - return self._java_summary.count() + return self.call("count") - @serialize def numNonzeros(self): - return self._java_summary.numNonzeros() + return self.call("numNonzeros").toArray() - @serialize def max(self): - return self._java_summary.max() + return self.call("max").toArray() - @serialize def min(self): - return self._java_summary.min() + return self.call("min").toArray() + + +class ChiSqTestResult(JavaModelWrapper): + """ + :: Experimental :: + + Object containing the test results for the chi-squared hypothesis test. + """ + @property + def method(self): + """ + Name of the test method + """ + return self._java_model.method() + + @property + def pValue(self): + """ + The probability of obtaining a test statistic result at least as + extreme as the one that was actually observed, assuming that the + null hypothesis is true. + """ + return self._java_model.pValue() + + @property + def degreesOfFreedom(self): + """ + Returns the degree(s) of freedom of the hypothesis test. + Return type should be Number(e.g. Int, Double) or tuples of Numbers. + """ + return self._java_model.degreesOfFreedom() + + @property + def statistic(self): + """ + Test statistic. + """ + return self._java_model.statistic() + + @property + def nullHypothesis(self): + """ + Null hypothesis of the test. + """ + return self._java_model.nullHypothesis() + + def __str__(self): + return self._java_model.toString() class Statistics(object): @@ -88,6 +108,11 @@ def colStats(rdd): """ Computes column-wise summary statistics for the input RDD[Vector]. + :param rdd: an RDD[Vector] for which column-wise summary statistics + are to be computed. + :return: :class:`MultivariateStatisticalSummary` object containing + column-wise summary statistics. + >>> from pyspark.mllib.linalg import Vectors >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]), ... Vectors.dense([4, 5, 0, 3]), @@ -106,10 +131,8 @@ def colStats(rdd): >>> cStats.min() array([ 2., 0., 0., -2.]) """ - sc = rdd.ctx - jrdd = _to_java_object_rdd(rdd.map(_convert_to_vector)) - cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd) - return MultivariateStatisticalSummary(sc, cStats) + cStats = callMLlibFunc("colStats", rdd.map(_convert_to_vector)) + return MultivariateStatisticalSummary(cStats) @staticmethod def corr(x, y=None, method=None): @@ -123,6 +146,13 @@ def corr(x, y=None, method=None): to specify the method to be used for single RDD inout. If two RDDs of floats are passed in, a single float is returned. + :param x: an RDD of vector for which the correlation matrix is to be computed, + or an RDD of float of the same cardinality as y when y is specified. + :param y: an RDD of float of the same cardinality as x. + :param method: String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman` + :return: Correlation matrix comparing columns in x. + >>> x = sc.parallelize([1.0, 0.0, -2.0], 2) >>> y = sc.parallelize([4.0, 5.0, 3.0], 2) >>> zeros = sc.parallelize([0.0, 0.0, 0.0], 2) @@ -156,7 +186,6 @@ def corr(x, y=None, method=None): ... except TypeError: ... pass """ - sc = x.ctx # Check inputs to determine whether a single value or a matrix is needed for output. # Since it's legal for users to use the method name as the second argument, we need to # check if y is used to specify the method name instead. @@ -164,15 +193,94 @@ def corr(x, y=None, method=None): raise TypeError("Use 'method=' to specify method name.") if not y: - jx = _to_java_object_rdd(x.map(_convert_to_vector)) - resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method) - bytes = sc._jvm.SerDe.dumps(resultMat) - ser = PickleSerializer() - return ser.loads(str(bytes)).toArray() + return callMLlibFunc("corr", x.map(_convert_to_vector), method).toArray() + else: + return callMLlibFunc("corr", x.map(float), y.map(float), method) + + @staticmethod + def chiSqTest(observed, expected=None): + """ + :: Experimental :: + + If `observed` is Vector, conduct Pearson's chi-squared goodness + of fit test of the observed data against the expected distribution, + or againt the uniform distribution (by default), with each category + having an expected frequency of `1 / len(observed)`. + (Note: `observed` cannot contain negative values) + + If `observed` is matrix, conduct Pearson's independence test on the + input contingency matrix, which cannot contain negative entries or + columns or rows that sum up to 0. + + If `observed` is an RDD of LabeledPoint, conduct Pearson's independence + test for every feature against the label across the input RDD. + For each feature, the (feature, label) pairs are converted into a + contingency matrix for which the chi-squared statistic is computed. + All label and feature values must be categorical. + + :param observed: it could be a vector containing the observed categorical + counts/relative frequencies, or the contingency matrix + (containing either counts or relative frequencies), + or an RDD of LabeledPoint containing the labeled dataset + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + :param expected: Vector containing the expected categorical counts/relative + frequencies. `expected` is rescaled if the `expected` sum + differs from the `observed` sum. + :return: ChiSquaredTest object containing the test statistic, degrees + of freedom, p-value, the method used, and the null hypothesis. + + >>> from pyspark.mllib.linalg import Vectors, Matrices + >>> observed = Vectors.dense([4, 6, 5]) + >>> pearson = Statistics.chiSqTest(observed) + >>> print pearson.statistic + 0.4 + >>> pearson.degreesOfFreedom + 2 + >>> print round(pearson.pValue, 4) + 0.8187 + >>> pearson.method + u'pearson' + >>> pearson.nullHypothesis + u'observed follows the same distribution as expected.' + + >>> observed = Vectors.dense([21, 38, 43, 80]) + >>> expected = Vectors.dense([3, 5, 7, 20]) + >>> pearson = Statistics.chiSqTest(observed, expected) + >>> print round(pearson.pValue, 4) + 0.0027 + + >>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] + >>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) + >>> print round(chi.statistic, 4) + 21.9958 + + >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), + ... LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), + ... LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), + ... LabeledPoint(0.0, Vectors.dense([3.5, 30.0])), + ... LabeledPoint(0.0, Vectors.dense([3.5, 40.0])), + ... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),] + >>> rdd = sc.parallelize(data, 4) + >>> chi = Statistics.chiSqTest(rdd) + >>> print chi[0].statistic + 0.75 + >>> print chi[1].statistic + 1.5 + """ + if isinstance(observed, RDD): + if not isinstance(observed.first(), LabeledPoint): + raise ValueError("observed should be an RDD of LabeledPoint") + jmodels = callMLlibFunc("chiSqTest", observed) + return [ChiSqTestResult(m) for m in jmodels] + + if isinstance(observed, Matrix): + jmodel = callMLlibFunc("chiSqTest", observed) else: - jx = _to_java_object_rdd(x.map(float)) - jy = _to_java_object_rdd(y.map(float)) - return sc._jvm.PythonMLLibAPI().corr(jx, jy, method) + if expected and len(expected) != len(observed): + raise ValueError("`expected` should have same length with `observed`") + jmodel = callMLlibFunc("chiSqTest", _convert_to_vector(observed), expected) + return ChiSqTestResult(jmodel) def _test(): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index d6fb87b378b4a..9fa4d6f6a2f5f 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -33,14 +33,14 @@ else: import unittest -from pyspark.serializers import PickleSerializer -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.serializers import PickleSerializer +from pyspark.sql import SQLContext from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase - _have_scipy = False try: import scipy.sparse @@ -221,6 +221,39 @@ def test_col_with_different_rdds(self): self.assertEqual(10, summary.count()) +class VectorUDTTests(PySparkTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + sqlCtx = SQLContext(self.sc) + rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) + srdd = sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = srdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise ValueError("expecting a vector but got %r of type %r" % (v, type(v))) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 64ee79d83e849..5d1a3c0962796 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -15,36 +15,22 @@ # limitations under the License. # -from py4j.java_collections import MapConverter - from pyspark import SparkContext, RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer -from pyspark.mllib.linalg import Vector, _convert_to_vector, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint __all__ = ['DecisionTreeModel', 'DecisionTree'] -class DecisionTreeModel(object): +class DecisionTreeModel(JavaModelWrapper): """ A decision tree model for classification or regression. EXPERIMENTAL: This is an experimental API. - It will probably be modified for Spark v1.2. + It will probably be modified in future. """ - - def __init__(self, sc, java_model): - """ - :param sc: Spark context - :param java_model: Handle to Java model object - """ - self._sc = sc - self._java_model = java_model - - def __del__(self): - self._sc._gateway.detach(self._java_model) - def predict(self, x): """ Predict the label of one or more examples. @@ -52,24 +38,11 @@ def predict(self, x): :param x: Data point (feature vector), or an RDD of data points (feature vectors). """ - SerDe = self._sc._jvm.SerDe - ser = PickleSerializer() if isinstance(x, RDD): - # Bulk prediction - first = x.take(1) - if not first: - return self._sc.parallelize([]) - if not isinstance(first[0], Vector): - x = x.map(_convert_to_vector) - jPred = self._java_model.predict(_to_java_object_rdd(x)).toJavaRDD() - jpyrdd = self._sc._jvm.SerDe.javaToPython(jPred) - return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024)) + return self.call("predict", x.map(_convert_to_vector)) else: - # Assume x is a single data point. - bytes = bytearray(ser.dumps(_convert_to_vector(x))) - vec = self._sc._jvm.SerDe.loads(bytes) - return self._java_model.predict(vec) + return self.call("predict", _convert_to_vector(x)) def numNodes(self): return self._java_model.numNodes() @@ -98,19 +71,13 @@ class DecisionTree(object): """ @staticmethod - def _train(data, type, numClasses, categoricalFeaturesInfo, - impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, - minInfoGain=0.0): + def _train(data, type, numClasses, features, impurity="gini", maxDepth=5, maxBins=32, + minInstancesPerNode=1, minInfoGain=0.0): first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" - sc = data.context - jrdd = _to_java_object_rdd(data) - cfiMap = MapConverter().convert(categoricalFeaturesInfo, - sc._gateway._gateway_client) - model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( - jrdd, type, numClasses, cfiMap, - impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) - return DecisionTreeModel(sc, model) + model = callMLlibFunc("trainDecisionTreeModel", data, type, numClasses, features, + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) + return DecisionTreeModel(model) @staticmethod def trainClassifier(data, numClasses, categoricalFeaturesInfo, diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 84b39a48619d2..4ed978b45409c 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -18,8 +18,7 @@ import numpy as np import warnings -from pyspark.rdd import RDD -from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.mllib.common import callMLlibFunc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -162,20 +161,11 @@ def loadLabeledPoints(sc, path, minPartitions=None): >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name) - >>> loaded = MLUtils.loadLabeledPoints(sc, tempFile.name).collect() - >>> type(loaded[0]) == LabeledPoint - True - >>> print examples[0] - (1.1,(3,[0,2],[-1.23,4.56e-07])) - >>> type(examples[1]) == LabeledPoint - True - >>> print examples[1] - (0.0,[1.01,2.02,3.03]) + >>> MLUtils.loadLabeledPoints(sc, tempFile.name).collect() + [LabeledPoint(1.1, (3,[0,2],[-1.23,4.56e-07])), LabeledPoint(0.0, [1.01,2.02,3.03])] """ minPartitions = minPartitions or min(sc.defaultParallelism, 2) - jrdd = sc._jvm.PythonMLLibAPI().loadLabeledPoints(sc._jsc, path, minPartitions) - jpyrdd = sc._jvm.SerDe.javaToPython(jrdd) - return RDD(jpyrdd, sc, AutoBatchedSerializer(PickleSerializer())) + return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) def _test(): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 15be4bfec92f9..08d047402625f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -120,7 +120,7 @@ class RDD(object): operated on in parallel. """ - def __init__(self, jrdd, ctx, jrdd_deserializer): + def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSerializer())): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False @@ -129,12 +129,8 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): self._id = jrdd.id() self._partitionFunc = None - def _toPickleSerialization(self): - if (self._jrdd_deserializer == PickleSerializer() or - self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): - return self - else: - return self._reserialize(BatchedSerializer(PickleSerializer(), 10)) + def _pickled(self): + return self._reserialize(AutoBatchedSerializer(PickleSerializer())) def id(self): """ @@ -316,9 +312,6 @@ def sample(self, withReplacement, fraction, seed=None): """ Return a sampled subset of this RDD (relies on numpy and falls back on default random generator if numpy is unavailable). - - >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP - [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) @@ -449,12 +442,11 @@ def intersection(self, other): def _reserialize(self, serializer=None): serializer = serializer or self.ctx.serializer - if self._jrdd_deserializer == serializer: - return self - else: - converted = self.map(lambda x: x, preservesPartitioning=True) - converted._jrdd_deserializer = serializer - return converted + if self._jrdd_deserializer != serializer: + if not isinstance(self, PipelinedRDD): + self = self.map(lambda x: x, preservesPartitioning=True) + self._jrdd_deserializer = serializer + return self def __add__(self, other): """ @@ -529,6 +521,8 @@ def sortPartition(iterator): # the key-space into bins such that the bins have roughly the same # number of (key, value) pairs falling into them rddSize = self.count() + if not rddSize: + return self # empty RDD maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner fraction = min(maxSampleSize / max(rddSize, 1), 1.0) samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() @@ -1123,9 +1117,8 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, True) def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, @@ -1150,9 +1143,8 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl :param conf: Hadoop job configuration, passed in as a dict (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, True, path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf) @@ -1169,9 +1161,8 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, False) def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, @@ -1198,9 +1189,8 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No :param compressionCodecClass: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, True, path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, @@ -1218,9 +1208,8 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): :param path: path to sequence file :param compressionCodecClass: (None by default) """ - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, batched, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, True, path, compressionCodecClass) def saveAsPickleFile(self, path, batchSize=10): @@ -1235,8 +1224,11 @@ def saveAsPickleFile(self, path, batchSize=10): >>> sorted(sc.pickleFile(tmpFile.name, 5).collect()) [1, 2, 'rdd', 'spark'] """ - self._reserialize(BatchedSerializer(PickleSerializer(), - batchSize))._jrdd.saveAsObjectFile(path) + if batchSize == 0: + ser = AutoBatchedSerializer(PickleSerializer()) + else: + ser = BatchedSerializer(PickleSerializer(), batchSize) + self._reserialize(ser)._jrdd.saveAsObjectFile(path) def saveAsTextFile(self, path): """ @@ -1777,13 +1769,10 @@ def zip(self, other): >>> x.zip(y).collect() [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)] """ - if self.getNumPartitions() != other.getNumPartitions(): - raise ValueError("Can only zip with RDD which has the same number of partitions") - def get_batch_size(ser): if isinstance(ser, BatchedSerializer): return ser.batchSize - return 0 + return 1 def batch_as(rdd, batchSize): ser = rdd._jrdd_deserializer @@ -1793,12 +1782,16 @@ def batch_as(rdd, batchSize): my_batch = get_batch_size(self._jrdd_deserializer) other_batch = get_batch_size(other._jrdd_deserializer) - if my_batch != other_batch: - # use the greatest batchSize to batch the other one. - if my_batch > other_batch: - other = batch_as(other, my_batch) - else: - self = batch_as(self, other_batch) + # use the smallest batchSize for both of them + batchSize = min(my_batch, other_batch) + if batchSize <= 0: + # auto batched or unlimited + batchSize = 100 + other = batch_as(other, batchSize) + self = batch_as(self, batchSize) + + if self.getNumPartitions() != other.getNumPartitions(): + raise ValueError("Can only zip with RDD which has the same number of partitions") # There will be an Exception in JVM if there are different number # of items in each partitions. @@ -1867,11 +1860,11 @@ def setName(self, name): Assign a name to this RDD. >>> rdd1 = sc.parallelize([1,2]) - >>> rdd1.setName('RDD1') - >>> rdd1.name() + >>> rdd1.setName('RDD1').name() 'RDD1' """ self._jrdd.setName(name) + return self def toDebugString(self): """ @@ -1937,25 +1930,14 @@ def lookup(self, key): return values.collect() - def _is_pickled(self): - """ Return this RDD is serialized by Pickle or not. """ - der = self._jrdd_deserializer - if isinstance(der, PickleSerializer): - return True - if isinstance(der, BatchedSerializer) and isinstance(der.serializer, PickleSerializer): - return True - return False - def _to_java_object_rdd(self): """ Return an JavaRDD of Object by unpickling It will convert each Python object into Java object by Pyrolite, whenever the RDD is serialized in batch or not. """ - rdd = self._reserialize(AutoBatchedSerializer(PickleSerializer())) \ - if not self._is_pickled() else self - is_batch = isinstance(rdd._jrdd_deserializer, BatchedSerializer) - return self.ctx._jvm.PythonRDD.pythonToJava(rdd._jrdd, is_batch) + rdd = self._pickled() + return self.ctx._jvm.SerDeUtil.pythonToJava(rdd._jrdd, True) def countApprox(self, timeout, confidence=0.95): """ @@ -2135,7 +2117,7 @@ def _test(): globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 528a181e8905a..f5c3cfd259a5b 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -40,14 +40,13 @@ def __init__(self, withReplacement, seed=None): def initRandomGenerator(self, split): if self._use_numpy: import numpy - self._random = numpy.random.RandomState(self._seed) + self._random = numpy.random.RandomState(self._seed ^ split) else: - self._random = random.Random(self._seed) + self._random = random.Random(self._seed ^ split) - for _ in range(0, split): - # discard the next few values in the sequence to have a - # different seed for the different splits - self._random.randint(0, 2 ** 32 - 1) + # mixing because the initial seeds are close to each other + for _ in xrange(10): + self._random.randint(0, 1) self._split = split self._rand_initialized = True diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 904bd9f2652d3..d597cbf94e1b1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -33,9 +33,8 @@ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.stop() -By default, PySpark serialize objects in batches; the batch size can be -controlled through SparkContext's C{batchSize} parameter -(the default size is 1024 objects): +PySpark serialize objects in batches; By default, the batch size is chosen based +on the size of objects, also configurable by SparkContext's C{batchSize} parameter: >>> sc = SparkContext('local', 'test', batchSize=2) >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) @@ -48,16 +47,6 @@ >>> rdd._jrdd.count() 8L >>> sc.stop() - -A batch size of -1 uses an unlimited batch size, and a size of 1 disables -batching: - ->>> sc = SparkContext('local', 'test', batchSize=1) ->>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) ->>> rdd.glom().collect() -[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ->>> rdd._jrdd.count() -16L """ import cPickle @@ -73,7 +62,7 @@ from pyspark import cloudpickle -__all__ = ["PickleSerializer", "MarshalSerializer"] +__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] class SpecialLengths(object): @@ -113,7 +102,7 @@ def __ne__(self, other): return not self.__eq__(other) def __repr__(self): - return "<%s object>" % self.__class__.__name__ + return "%s()" % self.__class__.__name__ def __hash__(self): return hash(str(self)) @@ -181,6 +170,7 @@ class BatchedSerializer(Serializer): """ UNLIMITED_BATCH_SIZE = -1 + UNKNOWN_BATCH_SIZE = 0 def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): self.serializer = serializer @@ -213,10 +203,10 @@ def _load_stream_without_unbatching(self, stream): def __eq__(self, other): return (isinstance(other, BatchedSerializer) and - other.serializer == self.serializer) + other.serializer == self.serializer and other.batchSize == self.batchSize) def __repr__(self): - return "BatchedSerializer<%s>" % str(self.serializer) + return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) class AutoBatchedSerializer(BatchedSerializer): @@ -225,7 +215,7 @@ class AutoBatchedSerializer(BatchedSerializer): """ def __init__(self, serializer, bestSize=1 << 16): - BatchedSerializer.__init__(self, serializer, -1) + BatchedSerializer.__init__(self, serializer, self.UNKNOWN_BATCH_SIZE) self.bestSize = bestSize def dump_stream(self, iterator, stream): @@ -248,10 +238,10 @@ def dump_stream(self, iterator, stream): def __eq__(self, other): return (isinstance(other, AutoBatchedSerializer) and - other.serializer == self.serializer) + other.serializer == self.serializer and other.bestSize == self.bestSize) def __str__(self): - return "AutoBatchedSerializer<%s>" % str(self.serializer) + return "AutoBatchedSerializer(%s)" % str(self.serializer) class CartesianDeserializer(FramedSerializer): @@ -284,7 +274,7 @@ def __eq__(self, other): self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __repr__(self): - return "CartesianDeserializer<%s, %s>" % \ + return "CartesianDeserializer(%s, %s)" % \ (str(self.key_ser), str(self.val_ser)) @@ -311,7 +301,7 @@ def __eq__(self, other): self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __repr__(self): - return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser)) + return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser)) class NoOpSerializer(FramedSerializer): @@ -430,7 +420,7 @@ def loads(self, obj): class AutoSerializer(FramedSerializer): """ - Choose marshal or cPickle as serialization protocol autumatically + Choose marshal or cPickle as serialization protocol automatically """ def __init__(self): diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index d57a802e4734a..5931e923c2e36 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -25,7 +25,7 @@ import random import pyspark.heapq3 as heapq -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer try: import psutil @@ -213,8 +213,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, Merger.__init__(self, aggregator) self.memory_limit = memory_limit # default serializer is only used for tests - self.serializer = serializer or \ - BatchedSerializer(PickleSerializer(), 1024) + self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions @@ -470,7 +469,7 @@ class ExternalSorter(object): def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") - self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024) + self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) def _get_path(self, n): """ Choose one directory for spill by number n """ diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 93fd9d49096b8..e5d62a466cab6 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -35,6 +35,7 @@ import keyword import warnings import json +import re from array import array from operator import itemgetter from itertools import imap @@ -43,7 +44,8 @@ from py4j.java_collections import ListConverter, MapConverter from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer +from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ + CloudPickleSerializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -108,6 +110,15 @@ def __eq__(self, other): return self is other +class NullType(PrimitiveType): + + """Spark SQL NullType + + The data type representing None, used for the types which has not + been inferred. + """ + + class StringType(PrimitiveType): """Spark SQL StringType @@ -148,13 +159,30 @@ class TimestampType(PrimitiveType): """ -class DecimalType(PrimitiveType): +class DecimalType(DataType): """Spark SQL DecimalType The data type representing decimal.Decimal values. """ + def __init__(self, precision=None, scale=None): + self.precision = precision + self.scale = scale + self.hasPrecisionInfo = precision is not None + + def jsonValue(self): + if self.hasPrecisionInfo: + return "decimal(%d,%d)" % (self.precision, self.scale) + else: + return "decimal" + + def __repr__(self): + if self.hasPrecisionInfo: + return "DecimalType(%d,%d)" % (self.precision, self.scale) + else: + return "DecimalType()" + class DoubleType(PrimitiveType): @@ -313,12 +341,15 @@ class StructField(DataType): """ - def __init__(self, name, dataType, nullable): + def __init__(self, name, dataType, nullable=True, metadata=None): """Creates a StructField :param name: the name of this field. :param dataType: the data type of this field. :param nullable: indicates whether values of this field can be null. + :param metadata: metadata of this field, which is a map from string + to simple type that can be serialized to JSON + automatically >>> (StructField("f1", StringType, True) ... == StructField("f1", StringType, True)) @@ -330,6 +361,7 @@ def __init__(self, name, dataType, nullable): self.name = name self.dataType = dataType self.nullable = nullable + self.metadata = metadata or {} def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, @@ -338,13 +370,15 @@ def __repr__(self): def jsonValue(self): return {"name": self.name, "type": self.dataType.jsonValue(), - "nullable": self.nullable} + "nullable": self.nullable, + "metadata": self.metadata} @classmethod def fromJson(cls, json): return StructField(json["name"], _parse_datatype_json_value(json["type"]), - json["nullable"]) + json["nullable"], + json["metadata"]) class StructType(DataType): @@ -384,6 +418,75 @@ def fromJson(cls, json): return StructType([StructField.fromJson(f) for f in json["fields"]]) +class UserDefinedType(DataType): + """ + :: WARN: Spark Internal Use Only :: + SQL User-Defined Type (UDT). + """ + + @classmethod + def typeName(cls): + return cls.__name__.lower() + + @classmethod + def sqlType(cls): + """ + Underlying SQL storage type for this UDT. + """ + raise NotImplementedError("UDT must implement sqlType().") + + @classmethod + def module(cls): + """ + The Python module of the UDT. + """ + raise NotImplementedError("UDT must implement module().") + + @classmethod + def scalaUDT(cls): + """ + The class name of the paired Scala UDT. + """ + raise NotImplementedError("UDT must have a paired Scala UDT.") + + def serialize(self, obj): + """ + Converts the a user-type object into a SQL datum. + """ + raise NotImplementedError("UDT must implement serialize().") + + def deserialize(self, datum): + """ + Converts a SQL datum into a user-type object. + """ + raise NotImplementedError("UDT must implement deserialize().") + + def json(self): + return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) + + def jsonValue(self): + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + return schema + + @classmethod + def fromJson(cls, json): + pyUDT = json["pyClass"] + split = pyUDT.rfind(".") + pyModule = pyUDT[:split] + pyClass = pyUDT[split+1:] + m = __import__(pyModule, globals(), locals(), [pyClass], -1) + UDT = getattr(m, pyClass) + return UDT() + + def __eq__(self, other): + return type(self) == type(other) + + _all_primitive_types = dict((v.typeName(), v) for v in globals().itervalues() if type(v) is PrimitiveTypeSingleton and @@ -423,7 +526,8 @@ def _parse_datatype_json_string(json_string): ... StructField("simpleArray", simple_arraytype, True), ... StructField("simpleMap", simple_maptype, True), ... StructField("simpleStruct", simple_structtype, True), - ... StructField("boolean", BooleanType(), False)]) + ... StructField("boolean", BooleanType(), False), + ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) >>> check_datatype(complex_structtype) True >>> # Complex ArrayType. @@ -435,19 +539,43 @@ def _parse_datatype_json_string(json_string): ... complex_arraytype, False) >>> check_datatype(complex_maptype) True + >>> check_datatype(ExamplePointUDT()) + True + >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> check_datatype(structtype_with_udt) + True """ return _parse_datatype_json_value(json.loads(json_string)) +_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") + + def _parse_datatype_json_value(json_value): - if type(json_value) is unicode and json_value in _all_primitive_types.keys(): - return _all_primitive_types[json_value]() + if type(json_value) is unicode: + if json_value in _all_primitive_types.keys(): + return _all_primitive_types[json_value]() + elif json_value == u'decimal': + return DecimalType() + elif _FIXED_DECIMAL.match(json_value): + m = _FIXED_DECIMAL.match(json_value) + return DecimalType(int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Could not parse datatype: %s" % json_value) else: - return _all_complex_types[json_value["type"]].fromJson(json_value) + tpe = json_value["type"] + if tpe in _all_complex_types: + return _all_complex_types[tpe].fromJson(json_value) + elif tpe == 'udt': + return UserDefinedType.fromJson(json_value) + else: + raise ValueError("not supported type: %s" % tpe) # Mapping Python types to Spark SQL DataType _type_mappings = { + type(None): NullType, bool: BooleanType, int: IntegerType, long: LongType, @@ -463,23 +591,34 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): - """Infer the DataType from obj""" + """Infer the DataType from obj + + >>> p = ExamplePoint(1.0, 2.0) + >>> _infer_type(p) + ExamplePointUDT + """ if obj is None: raise ValueError("Can not infer type for None") + if hasattr(obj, '__UDT__'): + return obj.__UDT__ + dataType = _type_mappings.get(type(obj)) if dataType is not None: return dataType() if isinstance(obj, dict): - if not obj: - raise ValueError("Can not infer type for empty dict") - key, value = obj.iteritems().next() - return MapType(_infer_type(key), _infer_type(value), True) + for key, value in obj.iteritems(): + if key is not None and value is not None: + return MapType(_infer_type(key), _infer_type(value), True) + else: + return MapType(NullType(), NullType(), True) elif isinstance(obj, (list, array)): - if not obj: - raise ValueError("Can not infer type for empty list/array") - return ArrayType(_infer_type(obj[0]), True) + for v in obj: + if v is not None: + return ArrayType(_infer_type(obj[0]), True) + else: + return ArrayType(NullType(), True) else: try: return _infer_schema(obj) @@ -512,60 +651,180 @@ def _infer_schema(row): return StructType(fields) -def _create_converter(obj, dataType): +def _need_python_to_sql_conversion(dataType): + """ + Checks whether we need python to sql conversion for the given type. + For now, only UDTs need this conversion. + + >>> _need_python_to_sql_conversion(DoubleType()) + False + >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), + ... StructField("values", ArrayType(DoubleType(), False), False)]) + >>> _need_python_to_sql_conversion(schema0) + False + >>> _need_python_to_sql_conversion(ExamplePointUDT()) + True + >>> schema1 = ArrayType(ExamplePointUDT(), False) + >>> _need_python_to_sql_conversion(schema1) + True + >>> schema2 = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> _need_python_to_sql_conversion(schema2) + True + """ + if isinstance(dataType, StructType): + return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + elif isinstance(dataType, ArrayType): + return _need_python_to_sql_conversion(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_python_to_sql_conversion(dataType.keyType) or \ + _need_python_to_sql_conversion(dataType.valueType) + elif isinstance(dataType, UserDefinedType): + return True + else: + return False + + +def _python_to_sql_converter(dataType): + """ + Returns a converter that converts a Python object into a SQL datum for the given type. + + >>> conv = _python_to_sql_converter(DoubleType()) + >>> conv(1.0) + 1.0 + >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) + >>> conv([1.0, 2.0]) + [1.0, 2.0] + >>> conv = _python_to_sql_converter(ExamplePointUDT()) + >>> conv(ExamplePoint(1.0, 2.0)) + [1.0, 2.0] + >>> schema = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> conv = _python_to_sql_converter(schema) + >>> conv((1.0, ExamplePoint(1.0, 2.0))) + (1.0, [1.0, 2.0]) + """ + if not _need_python_to_sql_conversion(dataType): + return lambda x: x + + if isinstance(dataType, StructType): + names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) + converters = map(_python_to_sql_converter, types) + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + return tuple(c(v) for c, v in zip(converters, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs + d = dict(obj) + return tuple(c(d.get(n)) for n, c in zip(names, converters)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + else: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return converter + elif isinstance(dataType, ArrayType): + element_converter = _python_to_sql_converter(dataType.elementType) + return lambda a: [element_converter(v) for v in a] + elif isinstance(dataType, MapType): + key_converter = _python_to_sql_converter(dataType.keyType) + value_converter = _python_to_sql_converter(dataType.valueType) + return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): + return lambda obj: dataType.serialize(obj) + else: + raise ValueError("Unexpected type %r" % dataType) + + +def _has_nulltype(dt): + """ Return whether there is NullType in `dt` or not """ + if isinstance(dt, StructType): + return any(_has_nulltype(f.dataType) for f in dt.fields) + elif isinstance(dt, ArrayType): + return _has_nulltype((dt.elementType)) + elif isinstance(dt, MapType): + return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) + else: + return isinstance(dt, NullType) + + +def _merge_type(a, b): + if isinstance(a, NullType): + return b + elif isinstance(b, NullType): + return a + elif type(a) is not type(b): + # TODO: type cast (such as int -> long) + raise TypeError("Can not merge type %s and %s" % (a, b)) + + # same type + if isinstance(a, StructType): + nfs = dict((f.name, f.dataType) for f in b.fields) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + for f in a.fields] + names = set([f.name for f in fields]) + for n in nfs: + if n not in names: + fields.append(StructField(n, nfs[n])) + return StructType(fields) + + elif isinstance(a, ArrayType): + return ArrayType(_merge_type(a.elementType, b.elementType), True) + + elif isinstance(a, MapType): + return MapType(_merge_type(a.keyType, b.keyType), + _merge_type(a.valueType, b.valueType), + True) + else: + return a + + +def _create_converter(dataType): """Create an converter to drop the names of fields in obj """ if isinstance(dataType, ArrayType): - conv = _create_converter(obj[0], dataType.elementType) + conv = _create_converter(dataType.elementType) return lambda row: map(conv, row) elif isinstance(dataType, MapType): - value = obj.values()[0] - conv = _create_converter(value, dataType.valueType) + conv = _create_converter(dataType.valueType) return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + elif isinstance(dataType, NullType): + return lambda x: None + elif not isinstance(dataType, StructType): return lambda x: x # dataType must be StructType names = [f.name for f in dataType.fields] + converters = [_create_converter(f.dataType) for f in dataType.fields] + + def convert_struct(obj): + if obj is None: + return + + if isinstance(obj, tuple): + if hasattr(obj, "fields"): + d = dict(zip(obj.fields, obj)) + if hasattr(obj, "__FIELDS__"): + d = dict(zip(obj.__FIELDS__, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): + d = dict(obj) + else: + raise ValueError("unexpected tuple: %s" % obj) - if isinstance(obj, dict): - conv = lambda o: tuple(o.get(n) for n in names) - - elif isinstance(obj, tuple): - if hasattr(obj, "_fields"): # namedtuple - conv = tuple - elif hasattr(obj, "__FIELDS__"): - conv = tuple - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): - conv = lambda o: tuple(v for k, v in o) + elif isinstance(obj, dict): + d = obj + elif hasattr(obj, "__dict__"): # object + d = obj.__dict__ else: - raise ValueError("unexpected tuple") - - elif hasattr(obj, "__dict__"): # object - conv = lambda o: [o.__dict__.get(n, None) for n in names] - - if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields): - return conv - - row = conv(obj) - convs = [_create_converter(v, f.dataType) - for v, f in zip(row, dataType.fields)] - - def nested_conv(row): - return tuple(f(v) for f, v in zip(convs, conv(row))) + raise ValueError("Unexpected obj: %s" % obj) - return nested_conv + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) - -def _drop_schema(rows, schema): - """ all the names of fields, becoming tuples""" - iterator = iter(rows) - row = iterator.next() - converter = _create_converter(row, schema) - yield converter(row) - for i in iterator: - yield converter(i) + return convert_struct _BRACKETS = {'(': ')', '[': ']', '{': '}'} @@ -677,7 +936,7 @@ def _infer_schema_type(obj, dataType): return _infer_type(obj) if not obj: - raise ValueError("Can not infer type from empty value") + return NullType() if isinstance(dataType, ArrayType): eType = _infer_schema_type(obj[0], dataType.elementType) @@ -739,11 +998,22 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... + >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... """ # all objects are nullable if obj is None: return + if isinstance(dataType, UserDefinedType): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError("%r is not an instance of type %r" % (obj, dataType)) + _verify_type(dataType.serialize(obj), dataType.sqlType()) + return + _type = type(dataType) assert _type in _acceptable_types, "unkown datatype: %s" % dataType @@ -818,6 +1088,8 @@ def _has_struct_or_date(dt): return _has_struct_or_date(dt.valueType) elif isinstance(dt, DateType): return True + elif isinstance(dt, UserDefinedType): + return True return False @@ -888,6 +1160,9 @@ def Dict(d): elif isinstance(dataType, DateType): return datetime.date + elif isinstance(dataType, UserDefinedType): + return lambda datum: dataType.deserialize(datum) + elif not isinstance(dataType, StructType): raise Exception("unexpected data type: %s" % dataType) @@ -959,7 +1234,6 @@ def __init__(self, sparkContext, sqlContext=None): self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm - self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray self._scala_SQLContext = sqlContext @property @@ -989,8 +1263,8 @@ def registerFunction(self, name, f, returnType=StringType()): """ func = lambda _, it: imap(lambda x: f(*x), it) command = (func, None, - BatchedSerializer(PickleSerializer(), 1024), - BatchedSerializer(PickleSerializer(), 1024)) + AutoBatchedSerializer(PickleSerializer()), + AutoBatchedSerializer(PickleSerializer())) ser = CloudPickleSerializer() pickled_command = ser.dumps(command) if len(pickled_command) > (1 << 20): # 1M @@ -1013,18 +1287,20 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc._javaAccumulator, returnType.json()) - def inferSchema(self, rdd): + def inferSchema(self, rdd, samplingRatio=None): """Infer and apply a schema to an RDD of L{Row}. - We peek at the first row of the RDD to determine the fields' names - and types. Nested collections are supported, which include array, - dict, list, Row, tuple, namedtuple, or object. + When samplingRatio is specified, the schema is inferred by looking + at the types of each row in the sampled dataset. Otherwise, the + first 100 rows of the RDD are inspected. Nested collections are + supported, which can include array, dict, list, Row, tuple, + namedtuple, or object. - All the rows in `rdd` should have the same type with the first one, - or it will cause runtime exceptions. + Each row could be L{pyspark.sql.Row} object or namedtuple or objects. + Using top level dicts is deprecated, as dict is used to represent Maps. - Each row could be L{pyspark.sql.Row} object or namedtuple or objects, - using dict is deprecated. + If a single column has multiple distinct inferred types, it may cause + runtime exceptions. >>> rdd = sc.parallelize( ... [Row(field1=1, field2="row1"), @@ -1061,8 +1337,23 @@ def inferSchema(self, rdd): warnings.warn("Using RDD of dict to inferSchema is deprecated," "please use pyspark.sql.Row instead") - schema = _infer_schema(first) - rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + warnings.warn("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio > 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + + converter = _create_converter(schema) + rdd = rdd.map(converter) return self.applySchema(rdd, schema) def applySchema(self, rdd, schema): @@ -1148,8 +1439,11 @@ def applySchema(self, rdd, schema): for row in rows: _verify_type(row, schema) - batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) - jrdd = self._pythonToJava(rdd._jrdd, batched) + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1183,7 +1477,7 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD() return SchemaRDD(jschema_rdd, self) - def jsonFile(self, path, schema=None): + def jsonFile(self, path, schema=None, samplingRatio=1.0): """ Loads a text file storing one JSON object per line as a L{SchemaRDD}. @@ -1191,8 +1485,8 @@ def jsonFile(self, path, schema=None): If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine - the schema. + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -1238,20 +1532,20 @@ def jsonFile(self, path, schema=None): [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - srdd = self._ssql_ctx.jsonFile(path) + srdd = self._ssql_ctx.jsonFile(path, samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) - def jsonRDD(self, rdd, schema=None): + def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine - the schema. + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. >>> srdd1 = sqlCtx.jsonRDD(json) >>> sqlCtx.registerRDDAsTable(srdd1, "table1") @@ -1308,7 +1602,7 @@ def func(iterator): keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) @@ -1400,33 +1694,6 @@ def hql(self, hqlQuery): class LocalHiveContext(HiveContext): - """Starts up an instance of hive where metadata is stored locally. - - An in-process metadata data is created with data stored in ./metadata. - Warehouse data is stored in in ./warehouse. - - >>> import os - >>> hiveCtx = LocalHiveContext(sc) - >>> try: - ... supress = hiveCtx.sql("DROP TABLE src") - ... except Exception: - ... pass - >>> kv1 = os.path.join(os.environ["SPARK_HOME"], - ... 'examples/src/main/resources/kv1.txt') - >>> supress = hiveCtx.sql( - ... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" - ... % kv1) - >>> results = hiveCtx.sql("FROM src SELECT value" - ... ).map(lambda r: int(r.value.split('_')[1])) - >>> num = results.count() - >>> reduce_sum = results.reduce(lambda x, y: x + y) - >>> num - 500 - >>> reduce_sum - 130091 - """ - def __init__(self, sparkContext, sqlContext=None): HiveContext.__init__(self, sparkContext, sqlContext) warnings.warn("LocalHiveContext is deprecated. " @@ -1573,7 +1840,7 @@ def __init__(self, jschema_rdd, sql_ctx): self.is_checkpointed = False self.ctx = self.sql_ctx._sc # the _jrdd is created by javaToPython(), serialized by pickle - self._jrdd_deserializer = BatchedSerializer(PickleSerializer()) + self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer()) @property def _jrdd(self): @@ -1803,15 +2070,13 @@ def subtract(self, other, numPartitions=None): def _test(): import doctest - from array import array from pyspark.context import SparkContext # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql from pyspark.sql import Row, SQLContext + from pyspark.tests import ExamplePoint, ExamplePointUDT globs = pyspark.sql.__dict__.copy() - # The small batch size here ensures that we see multiple batches, - # even in these small test examples: - sc = SparkContext('local[4]', 'PythonTest', batchSize=2) + sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlCtx'] = SQLContext(sc) globs['rdd'] = sc.parallelize( @@ -1819,6 +2084,8 @@ def _test(): Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) + globs['ExamplePoint'] = ExamplePoint + globs['ExamplePointUDT'] = ExamplePointUDT jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 37a128907b3a7..491e445a216bf 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -49,7 +49,8 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType +from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType, DoubleType from pyspark import shuffle _have_scipy = False @@ -241,7 +242,7 @@ class PySparkTestCase(unittest.TestCase): def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ - self.sc = SparkContext('local[4]', class_name, batchSize=2) + self.sc = SparkContext('local[4]', class_name) def tearDown(self): self.sc.stop() @@ -252,7 +253,7 @@ class ReusedPySparkTestCase(unittest.TestCase): @classmethod def setUpClass(cls): - cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2) + cls.sc = SparkContext('local[4]', cls.__name__) @classmethod def tearDownClass(cls): @@ -648,6 +649,24 @@ def test_distinct(self): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_sort_on_empty_rdd(self): + self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) + + def test_sample(self): + rdd = self.sc.parallelize(range(0, 100), 4) + wo = rdd.sample(False, 0.1, 2).collect() + wo_dup = rdd.sample(False, 0.1, 2).collect() + self.assertSetEqual(set(wo), set(wo_dup)) + wr = rdd.sample(True, 0.2, 5).collect() + wr_dup = rdd.sample(True, 0.2, 5).collect() + self.assertSetEqual(set(wr), set(wr_dup)) + wo_s10 = rdd.sample(False, 0.3, 10).collect() + wo_s20 = rdd.sample(False, 0.3, 20).collect() + self.assertNotEqual(set(wo_s10), set(wo_s20)) + wr_s11 = rdd.sample(True, 0.4, 11).collect() + wr_s21 = rdd.sample(True, 0.4, 21).collect() + self.assertNotEqual(set(wr_s11), set(wr_s21)) + class ProfilerTests(PySparkTestCase): @@ -655,7 +674,7 @@ def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ conf = SparkConf().set("spark.python.profile", "true") - self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf) + self.sc = SparkContext('local[4]', class_name, conf=conf) def test_profiler(self): @@ -679,8 +698,65 @@ def heavy_foo(x): self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. + """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + class SQLTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) + def setUp(self): self.sqlCtx = SQLContext(self.sc) @@ -781,6 +857,25 @@ def test_serialize_nested_array_and_map(self): self.assertEqual(1.0, row.c) self.assertEqual("2", row.d) + def test_infer_schema(self): + d = [Row(l=[], d={}), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + srdd = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], srdd.map(lambda r: r.l).first()) + self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect()) + srdd.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.first()[0]) + + srdd2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(srdd.schema(), srdd2.schema()) + self.assertEqual({}, srdd2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect()) + srdd2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.first()[0]) + def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.asDict()['l'][0].a) @@ -790,6 +885,39 @@ def test_convert_row_to_dict(self): row = self.sqlCtx.sql("select l[0].a AS la from test").first() self.assertEqual(1, row.asDict()["la"]) + def test_infer_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd = self.sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + srdd.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + srdd = self.sqlCtx.applySchema(rdd, schema) + point = srdd.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd0 = self.sqlCtx.inferSchema(rdd) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + srdd0.saveAsParquetFile(output_dir) + srdd1 = self.sqlCtx.parquetFile(output_dir) + point = srdd1.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + class InputFormatTests(ReusedPySparkTestCase): @@ -887,16 +1015,19 @@ def test_sequencefiles(self): clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable").collect()) - ec = (u'1', - {u'__class__': u'org.apache.spark.api.python.TestWritable', - u'double': 54.0, u'int': 123, u'str': u'test1'}) - self.assertEqual(clazz[0], ec) + cname = u'org.apache.spark.api.python.TestWritable' + ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), + (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), + (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), + (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), + (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] + self.assertEqual(clazz, ec) unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable", - batchSize=1).collect()) - self.assertEqual(unbatched_clazz[0], ec) + ).collect()) + self.assertEqual(unbatched_clazz, ec) def test_oldhadoop(self): basepath = self.tempdir.name @@ -982,6 +1113,25 @@ def test_converters(self): (u'\x03', [2.0])] self.assertEqual(maps, em) + def test_binary_files(self): + path = os.path.join(self.tempdir.name, "binaryfiles") + os.mkdir(path) + data = "short binary data" + with open(os.path.join(path, "part-0000"), 'w') as f: + f.write(data) + [(p, d)] = self.sc.binaryFiles(path).collect() + self.assertTrue(p.endswith("part-0000")) + self.assertEqual(d, data) + + def test_binary_records(self): + path = os.path.join(self.tempdir.name, "binaryrecords") + os.mkdir(path) + with open(os.path.join(path, "part-0000"), 'w') as f: + for i in range(100): + f.write('%04d' % i) + result = self.sc.binaryRecords(path, 4).map(int).collect() + self.assertEqual(range(100), result) + class OutputFormatTests(ReusedPySparkTestCase): @@ -1216,51 +1366,6 @@ def test_reserialization(self): result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) self.assertEqual(result5, data) - def test_unbatched_save_and_read(self): - basepath = self.tempdir.name - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.sc.parallelize(ei, len(ei)).saveAsSequenceFile( - basepath + "/unbatched/") - - unbatched_sequence = sorted(self.sc.sequenceFile( - basepath + "/unbatched/", - batchSize=1).collect()) - self.assertEqual(unbatched_sequence, ei) - - unbatched_hadoopFile = sorted(self.sc.hadoopFile( - basepath + "/unbatched/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - batchSize=1).collect()) - self.assertEqual(unbatched_hadoopFile, ei) - - unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile( - basepath + "/unbatched/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - batchSize=1).collect()) - self.assertEqual(unbatched_newAPIHadoopFile, ei) - - oldconf = {"mapred.input.dir": basepath + "/unbatched/"} - unbatched_hadoopRDD = sorted(self.sc.hadoopRDD( - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=oldconf, - batchSize=1).collect()) - self.assertEqual(unbatched_hadoopRDD, ei) - - newconf = {"mapred.input.dir": basepath + "/unbatched/"} - unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=newconf, - batchSize=1).collect()) - self.assertEqual(unbatched_newAPIHadoopRDD, ei) - def test_malformed_RDD(self): basepath = self.tempdir.name # non-batch-serialized RDD[[(K, V)]] should be rejected diff --git a/sbt/sbt-launch-lib.bash b/sbt/sbt-launch-lib.bash index 7f05d2ef491a3..055e206662654 100755 --- a/sbt/sbt-launch-lib.bash +++ b/sbt/sbt-launch-lib.bash @@ -124,7 +124,8 @@ require_arg () { local opt="$2" local arg="$3" if [[ -z "$arg" ]] || [[ "${arg:0:1}" == "-" ]]; then - die "$opt requires <$type> argument" + echo "$opt requires <$type> argument" 1>&2 + exit 1 fi } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d76c743d3f652..71034c2c43c77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -19,85 +19,145 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference} +import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal + /** - * Provides experimental support for generating catalyst schemas for scala objects. + * A default version of ScalaReflection that uses the runtime universe. */ -object ScalaReflection { +object ScalaReflection extends ScalaReflection { + val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe +} + +/** + * Support for generating catalyst schemas for scala objects. + */ +trait ScalaReflection { + /** The universe we work in (runtime or macro) */ + val universe: scala.reflect.api.Universe + + import universe._ + // The Predef.Map is scala.collection.immutable.Map. // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map - import scala.reflect.runtime.universe._ case class Schema(dataType: DataType, nullable: Boolean) - /** Converts Scala objects to catalyst rows / types */ - def convertToCatalyst(a: Any): Any = a match { - case o: Option[_] => o.map(convertToCatalyst).orNull - case s: Seq[_] => s.map(convertToCatalyst) - case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case other => other + /** + * Converts Scala objects to catalyst rows / types. + * Note: This is always called after schemaFor has been called. + * This ordering is important for UDT registration. + */ + def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) + case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull + case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) + case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => + convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) + } + case (p: Product, structType: StructType) => + new GenericRow( + p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => + convertToCatalyst(elem, field.dataType) + }.toArray) + case (d: BigDecimal, _) => Decimal(d) + case (other, _) => other + } + + /** Converts Catalyst types used internally in rows to standard Scala types */ + def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (d, udt: UserDefinedType[_]) => udt.deserialize(d) + case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType)) + case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => + convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) + } + case (r: Row, s: StructType) => convertRowToScala(r, s) + case (d: Decimal, _: DecimalType) => d.toBigDecimal + case (other, _) => other + } + + def convertRowToScala(r: Row, schema: StructType): Row = { + new GenericRow( + r.zip(schema.fields.map(_.dataType)) + .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray) } /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => - s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T]) /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor(tpe: `Type`): Schema = tpe match { - case t if t <:< typeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - Schema(schemaFor(optType).dataType, nullable = true) - case t if t <:< typeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val params = t.member(nme.CONSTRUCTOR).asMethod.paramss - Schema(StructType( - params.head.map { p => - val Schema(dataType, nullable) = - schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) - StructField(p.name.toString, dataType, nullable) - }), nullable = true) - // Need to decide if we actually need a special type here. - case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) - case t if t <:< typeOf[Array[_]] => - sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") - case t if t <:< typeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< typeOf[Map[_,_]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - Schema(MapType(schemaFor(keyType).dataType, - valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< typeOf[String] => Schema(StringType, nullable = true) - case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) - case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) - case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true) - case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) - case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) - case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) - case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) - case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) - case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) - case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) - case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) - case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) - case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) - case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) - case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) - case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) - case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + def schemaFor(tpe: `Type`): Schema = { + val className: String = tpe.erasure.typeSymbol.asClass.fullName + tpe match { + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection, + // whereas className is from Scala reflection. This can make it hard to find classes + // in some cases, such as when a class is enclosed in an object (in which case + // Java appends a '$' to the object name but Scala does not). + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + Schema(udt, nullable = true) + case t if t <:< typeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + Schema(schemaFor(optType).dataType, nullable = true) + case t if t <:< typeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val params = t.member(nme.CONSTRUCTOR).asMethod.paramss + Schema(StructType( + params.head.map { p => + val Schema(dataType, nullable) = + schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) + StructField(p.name.toString, dataType, nullable) + }), nullable = true) + // Need to decide if we actually need a special type here. + case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) + case t if t <:< typeOf[Array[_]] => + sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") + case t if t <:< typeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + Schema(MapType(schemaFor(keyType).dataType, + valueDataType, valueContainsNull = valueNullable), nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) + case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) + case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) + case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) + case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) + case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) + case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) + case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) + case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) + case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) + case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) + case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) + case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) + case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) + case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + } } def typeOfObject: PartialFunction[Any, DataType] = { @@ -111,8 +171,9 @@ object ScalaReflection { case obj: LongType.JvmType => LongType case obj: FloatType.JvmType => FloatType case obj: DoubleType.JvmType => DoubleType - case obj: DecimalType.JvmType => DecimalType case obj: DateType.JvmType => DateType + case obj: BigDecimal => DecimalType.Unlimited + case obj: Decimal => DecimalType.Unlimited case obj: TimestampType.JvmType => TimestampType case null => NullType // For other cases, there is no obvious mapping from the type of the given object to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala index 12e8346a6445d..f5c19ee69c37a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala @@ -137,7 +137,6 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr protected val LAZY = Keyword("LAZY") protected val SET = Keyword("SET") protected val TABLE = Keyword("TABLE") - protected val SOURCE = Keyword("SOURCE") protected val UNCACHE = Keyword("UNCACHE") protected implicit def asParser(k: Keyword): Parser[String] = @@ -152,8 +151,7 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr override val lexical = new SqlLexical(reservedWords) - override protected lazy val start: Parser[LogicalPlan] = - cache | uncache | set | shell | source | others + override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others private lazy val cache: Parser[LogicalPlan] = CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { @@ -171,16 +169,6 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr case input => SetCommandParser(input) } - private lazy val shell: Parser[LogicalPlan] = - "!" ~> restInput ^^ { - case input => ShellCommand(input.trim) - } - - private lazy val source: Parser[LogicalPlan] = - SOURCE ~> restInput ^^ { - case input => SourceCommand(input.trim) - } - private lazy val others: Parser[LogicalPlan] = wholeInput ^^ { case input => fallback(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 0acf7252ba3f0..affef276c2a88 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -52,8 +52,10 @@ class SqlParser extends AbstractSparkSQLParser { protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") protected val COUNT = Keyword("COUNT") + protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") + protected val DOUBLE = Keyword("DOUBLE") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") @@ -166,7 +168,8 @@ class SqlParser extends AbstractSparkSQLParser { // Based very loosely on the MySQL Grammar. // http://dev.mysql.com/doc/refman/5.0/en/join.html protected lazy val relations: Parser[LogicalPlan] = - ( relation ~ ("," ~> relation) ^^ { case r1 ~ r2 => Join(r1, r2, Inner, None) } + ( relation ~ rep1("," ~> relation) ^^ { + case r1 ~ joins => joins.foldLeft(r1) { case(lhs, r) => Join(lhs, r, Inner, None) } } | relation ) @@ -231,12 +234,15 @@ class SqlParser extends AbstractSparkSQLParser { | termExpression ~ (">=" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThanOrEqual(e1, e2) } | termExpression ~ ("!=" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } | termExpression ~ ("<>" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } - | termExpression ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { - case e ~ el ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) + | termExpression ~ NOT.? ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { + case e ~ not ~ el ~ eu => + val betweenExpr: Expression = And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) + not.fold(betweenExpr)(f=> Not(betweenExpr)) } | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) } + | termExpression ~ (NOT ~ LIKE ~> termExpression) ^^ { case e1 ~ e2 => Not(Like(e1, e2)) } | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { case e1 ~ e2 => In(e1, e2) } @@ -382,5 +388,15 @@ class SqlParser extends AbstractSparkSQLParser { } protected lazy val dataType: Parser[DataType] = - STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType + ( STRING ^^^ StringType + | TIMESTAMP ^^^ TimestampType + | DOUBLE ^^^ DoubleType + | fixedDecimalType + | DECIMAL ^^^ DecimalType.Unlimited + ) + + protected lazy val fixedDecimalType: Parser[DataType] = + (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 2059a91ba0612..0415d74bd8141 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -28,6 +28,8 @@ trait Catalog { def caseSensitive: Boolean + def tableExists(db: Option[String], tableName: String): Boolean + def lookupRelation( databaseName: Option[String], tableName: String, @@ -82,6 +84,14 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { tables.clear() } + override def tableExists(db: Option[String], tableName: String): Boolean = { + val (dbName, tblName) = processDatabaseAndTableName(db, tableName) + tables.get(tblName) match { + case Some(_) => true + case None => false + } + } + override def lookupRelation( databaseName: Option[String], tableName: String, @@ -107,6 +117,14 @@ trait OverrideCatalog extends Catalog { // TODO: This doesn't work when the database changes... val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]() + abstract override def tableExists(db: Option[String], tableName: String): Boolean = { + val (dbName, tblName) = processDatabaseAndTableName(db, tableName) + overrides.get((dbName, tblName)) match { + case Some(_) => true + case None => super.tableExists(db, tableName) + } + } + abstract override def lookupRelation( databaseName: Option[String], tableName: String, @@ -149,6 +167,10 @@ object EmptyCatalog extends Catalog { val caseSensitive: Boolean = true + def tableExists(db: Option[String], tableName: String): Boolean = { + throw new UnsupportedOperationException + } + def lookupRelation( databaseName: Option[String], tableName: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2b69c02b28285..e38114ab3cf25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -25,19 +25,31 @@ import org.apache.spark.sql.catalyst.types._ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: - val numericPrecedence = - Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType) - val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil + private val numericPrecedence = + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited) + /** + * Find the tightest common type of two types that might be used in a binary expression. + * This handles all numeric types except fixed-precision decimals interacting with each other or + * with primitive types, because in that case the precision and scale of the result depends on + * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. + */ def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { val valueTypes = Seq(t1, t2).filter(t => t != NullType) if (valueTypes.distinct.size > 1) { - // Try and find a promotion rule that contains both types in question. - val applicableConversion = - HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) - - // If found return the widest common type, otherwise None - applicableConversion.map(_.filter(t => t == t1 || t == t2).last) + // Promote numeric types to the highest of the two and all numeric types to unlimited decimal + if (numericPrecedence.contains(t1) && numericPrecedence.contains(t2)) { + Some(numericPrecedence.filter(t => t == t1 || t == t2).last) + } else if (t1.isInstanceOf[DecimalType] && t2.isInstanceOf[DecimalType]) { + // Fixed-precision decimals can up-cast into unlimited + if (t1 == DecimalType.Unlimited || t2 == DecimalType.Unlimited) { + Some(DecimalType.Unlimited) + } else { + None + } + } else { + None + } } else { Some(if (valueTypes.size == 0) NullType else valueTypes.head) } @@ -59,6 +71,7 @@ trait HiveTypeCoercion { ConvertNaNs :: WidenTypes :: PromoteStrings :: + DecimalPrecision :: BooleanComparisons :: BooleanCasts :: StringToIntegralCasts :: @@ -151,6 +164,7 @@ trait HiveTypeCoercion { import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // TODO: unions with fixed-precision decimals case u @ Union(left, right) if u.childrenResolved && !u.resolved => val castedInput = left.output.zip(right.output).map { // When a string is found on one side, make the other side a string too. @@ -265,6 +279,110 @@ trait HiveTypeCoercion { } } + // scalastyle:off + /** + * Calculates and propagates precision for fixed-precision decimals. Hive has a number of + * rules for this based on the SQL standard and MS SQL: + * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + * + * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 + * respectively, then the following operations have the following precision / scale: + * + * Operation Result Precision Result Scale + * ------------------------------------------------------------------------ + * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 * e2 p1 + p2 + 1 s1 + s2 + * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) + * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) + * sum(e1) p1 + 10 s1 + * avg(e1) p1 + 4 s1 + 4 + * + * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision. + * + * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited + * precision, do the math on unlimited-precision numbers, then introduce casts back to the + * required fixed precision. This allows us to do all rounding and overflow handling in the + * cast-to-fixed-precision operator. + * + * In addition, when mixing non-decimal types with decimals, we use the following rules: + * - BYTE gets turned into DECIMAL(3, 0) + * - SHORT gets turned into DECIMAL(5, 0) + * - INT gets turned into DECIMAL(10, 0) + * - LONG gets turned into DECIMAL(20, 0) + * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, + * but note that unlimited decimals are considered bigger than doubles in WidenTypes) + */ + // scalastyle:on + object DecimalPrecision extends Rule[LogicalPlan] { + import scala.math.{max, min} + + // Conversion rules for integer types into fixed-precision decimals + val intTypeToFixed: Map[DataType, DecimalType] = Map( + ByteType -> DecimalType(3, 0), + ShortType -> DecimalType(5, 0), + IntegerType -> DecimalType(10, 0), + LongType -> DecimalType(20, 0) + ) + + def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes whose children have not been resolved yet + case e if !e.childrenResolved => e + + case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 + p2 + 1, s1 + s2) + ) + + case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) + ) + + case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + ) + + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case b: BinaryExpression if b.left.dataType != b.right.dataType => + (b.left.dataType, b.right.dataType) match { + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) + case (t, DecimalType.Fixed(p, s)) if isFloat(t) => + b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) + case (DecimalType.Fixed(p, s), t) if isFloat(t) => + b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + case _ => + b + } + + // TODO: MaxOf, MinOf, etc might want other rules + + // SUM and AVERAGE are handled by the implementations of those expressions + } + } + /** * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. */ @@ -330,7 +448,7 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e case Cast(e @ StringType(), t: IntegralType) => - Cast(Cast(e, DecimalType), t) + Cast(Cast(e, DecimalType.Unlimited), t) } } @@ -383,10 +501,12 @@ trait HiveTypeCoercion { // Decimal and Double remain the same case d: Divide if d.resolved && d.dataType == DoubleType => d - case d: Divide if d.resolved && d.dataType == DecimalType => d + case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d - case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType)) - case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r) + case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] => + Divide(l, Cast(r, DecimalType.Unlimited)) + case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] => + Divide(Cast(l, DecimalType.Unlimited), r) case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java new file mode 100644 index 0000000000000..e966aeea1cb23 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.annotation; + +import java.lang.annotation.*; + +import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.sql.catalyst.types.UserDefinedType; + +/** + * ::DeveloperApi:: + * A user-defined type which can be automatically recognized by a SQLContext and registered. + * + * WARNING: This annotation will only work if both Java and Scala reflection return the same class + * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class + * is enclosed in an object (a singleton). + * + * WARNING: UDTs are currently only supported from Scala. + */ +// TODO: Should I used @Documented ? +@DeveloperApi +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface SQLUserDefinedType { + + /** + * Returns an instance of the UserDefinedType which can serialize and deserialize the user + * class to and from Catalyst built-in types. + */ + Class > udt(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 23cfd483ec410..31dc5a58e68e5 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -19,7 +19,10 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.language.implicitConversions +import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ @@ -107,7 +110,8 @@ package object dsl { def asc = SortOrder(expr, Ascending) def desc = SortOrder(expr, Descending) - def as(s: Symbol) = Alias(expr, s.name)() + def as(alias: String) = Alias(expr, alias)() + def as(alias: Symbol) = Alias(expr, alias.name)() } trait ExpressionConversions { @@ -124,7 +128,8 @@ package object dsl { implicit def doubleToLiteral(d: Double) = Literal(d) implicit def stringToLiteral(s: String) = Literal(s) implicit def dateToLiteral(d: Date) = Literal(d) - implicit def decimalToLiteral(d: BigDecimal) = Literal(d) + implicit def bigDecimalToLiteral(d: BigDecimal) = Literal(d) + implicit def decimalToLiteral(d: Decimal) = Literal(d) implicit def timestampToLiteral(t: Timestamp) = Literal(t) implicit def binaryToLiteral(a: Array[Byte]) = Literal(a) @@ -183,7 +188,11 @@ package object dsl { def date = AttributeReference(s, DateType, nullable = true)() /** Creates a new AttributeReference of type decimal */ - def decimal = AttributeReference(s, DecimalType, nullable = true)() + def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)() + + /** Creates a new AttributeReference of type decimal */ + def decimal(precision: Int, scale: Int) = + AttributeReference(s, DecimalType(precision, scale), nullable = true)() /** Creates a new AttributeReference of type timestamp */ def timestamp = AttributeReference(s, TimestampType, nullable = true)() @@ -278,4 +287,62 @@ package object dsl { def writeToFile(path: String) = WriteToFile(path, logicalPlan) } } + + case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { + def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + } + + // scalastyle:off + /** functionToUdfBuilder 1-22 were generated by this script + + (1 to 22).map { x => + val argTypes = Seq.fill(x)("_").mkString(", ") + s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]) = ScalaUdfBuilder(func)" + } + */ + + implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + // scalastyle:on } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8e5baf0eb82d6..55319e7a79103 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,6 +23,7 @@ import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /** Cast the child expression to the target data type. */ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { @@ -36,6 +37,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (BooleanType, DateType) => true case (DateType, _: NumericType) => true case (DateType, BooleanType) => true + case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null case _ => child.nullable } @@ -76,8 +78,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Short](_, _ != 0) case ByteType => buildCast[Byte](_, _ != 0) - case DecimalType => - buildCast[BigDecimal](_, _ != 0) + case DecimalType() => + buildCast[Decimal](_, _ != 0) case DoubleType => buildCast[Double](_, _ != 0) case FloatType => @@ -109,19 +111,19 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Date](_, d => new Timestamp(d.getTime)) // TimestampWritable.decimalToTimestamp - case DecimalType => - buildCast[BigDecimal](_, d => decimalToTimestamp(d)) + case DecimalType() => + buildCast[Decimal](_, d => decimalToTimestamp(d)) // TimestampWritable.doubleToTimestamp case DoubleType => - buildCast[Double](_, d => decimalToTimestamp(d)) + buildCast[Double](_, d => decimalToTimestamp(Decimal(d))) // TimestampWritable.floatToTimestamp case FloatType => - buildCast[Float](_, f => decimalToTimestamp(f)) + buildCast[Float](_, f => decimalToTimestamp(Decimal(f))) } - private[this] def decimalToTimestamp(d: BigDecimal) = { + private[this] def decimalToTimestamp(d: Decimal) = { val seconds = Math.floor(d.toDouble).toLong - val bd = (d - seconds) * 1000000000 + val bd = (d.toBigDecimal - seconds) * 1000000000 val nanos = bd.intValue() val millis = seconds * 1000 @@ -196,8 +198,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t)) - case DecimalType => - buildCast[BigDecimal](_, _.toLong) + case DecimalType() => + buildCast[Decimal](_, _.toLong) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } @@ -214,8 +216,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toInt) - case DecimalType => - buildCast[BigDecimal](_, _.toInt) + case DecimalType() => + buildCast[Decimal](_, _.toInt) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } @@ -232,8 +234,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toShort) - case DecimalType => - buildCast[BigDecimal](_, _.toShort) + case DecimalType() => + buildCast[Decimal](_, _.toShort) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } @@ -250,27 +252,45 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toByte) - case DecimalType => - buildCast[BigDecimal](_, _.toByte) + case DecimalType() => + buildCast[Decimal](_, _.toByte) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } - // DecimalConverter - private[this] def castToDecimal: Any => Any = child.dataType match { + /** + * Change the precision / scale in a given decimal to those set in `decimalType` (if any), + * returning null if it overflows or modifying `value` in-place and returning it if successful. + * + * NOTE: this modifies `value` in-place, so don't call it on external data. + */ + private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { + decimalType match { + case DecimalType.Unlimited => + value + case DecimalType.Fixed(precision, scale) => + if (value.changePrecision(precision, scale)) value else null + } + } + + private[this] def castToDecimal(target: DecimalType): Any => Any = child.dataType match { case StringType => - buildCast[String](_, s => try BigDecimal(s.toDouble) catch { + buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch { case _: NumberFormatException => null }) case BooleanType => - buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0)) + buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target)) case DateType => - buildCast[Date](_, d => dateToDouble(d)) + buildCast[Date](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. - buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) - case x: NumericType => - b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) + buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) + case DecimalType() => + b => changePrecision(b.asInstanceOf[Decimal].clone(), target) + case LongType => + b => changePrecision(Decimal(b.asInstanceOf[Long]), target) + case x: NumericType => // All other numeric types can be represented precisely as Doubles + b => changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target) } // DoubleConverter @@ -285,8 +305,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t)) - case DecimalType => - buildCast[BigDecimal](_, _.toDouble) + case DecimalType() => + buildCast[Decimal](_, _.toDouble) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) } @@ -303,8 +323,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) - case DecimalType => - buildCast[BigDecimal](_, _.toFloat) + case DecimalType() => + buildCast[Decimal](_, _.toFloat) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } @@ -313,8 +333,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case dt if dt == child.dataType => identity[Any] case StringType => castToString case BinaryType => castToBinary - case DecimalType => castToDecimal case DateType => castToDate + case decimal: DecimalType => castToDecimal(decimal) case TimestampType => castToTimestamp case BooleanType => castToBoolean case ByteType => castToByte diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1eb260efa6387..39b120e8de485 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType} +import org.apache.spark.sql.catalyst.util.Metadata abstract class Expression extends TreeNode[Expression] { self: Product => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 1b687a443ef8b..18c96da2f87fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -21,6 +21,10 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.util.ClosureCleaner +/** + * User-defined function. + * @param dataType Return type of function. + */ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression { @@ -30,323 +34,369 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi override def toString = s"scalaUDF(${children.mkString(",")})" + // scalastyle:off + /** This method has been generated by this script (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) - val evals = (0 to x - 1).map(x => s"children($x).eval(input)").reduce(_ + ",\n " + _) + val evals = (0 to x - 1).map(x => s" ScalaReflection.convertToScala(children($x).eval(input), children($x).dataType)").reduce(_ + ",\n " + _) s""" case $x => function.asInstanceOf[($anys) => Any]( - $evals) + $evals) """ - } + }.foreach(println) */ - // scalastyle:off override def eval(input: Row): Any = { val result = children.size match { case 0 => function.asInstanceOf[() => Any]() - case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) + case 1 => + function.asInstanceOf[(Any) => Any]( + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType)) + + case 2 => function.asInstanceOf[(Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType)) + + case 3 => function.asInstanceOf[(Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType)) + + case 4 => function.asInstanceOf[(Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType)) + + case 5 => function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType)) + + case 6 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType)) + + case 7 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType)) + + case 8 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType)) + + case 9 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType)) + + case 10 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType)) + + case 11 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType)) + + case 12 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType)) + + case 13 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType)) + + case 14 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType)) + + case 15 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType)) + + case 16 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType)) + + case 17 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType)) + + case 18 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType)) + + case 19 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType)) + + case 20 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType)) + + case 21 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input), - children(20).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), + ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType)) + + case 22 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input), - children(20).eval(input), - children(21).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), + ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType), + ScalaReflection.convertToScala(children(21).eval(input), children(21).dataType)) + } // scalastyle:on - ScalaReflection.convertToCatalyst(result) + ScalaReflection.convertToCatalyst(result, dataType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 1b4d892625dbb..2b364fc1df1d8 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -286,18 +286,38 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { override def nullable = false - override def dataType = DoubleType + + override def dataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + DoubleType + } + override def toString = s"AVG($child)" override def asPartial: SplitEvaluation = { val partialSum = Alias(Sum(child), "PartialSum")() val partialCount = Alias(Count(child), "PartialCount")() - val castedSum = Cast(Sum(partialSum.toAttribute), dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), dataType) - SplitEvaluation( - Divide(castedSum, castedCount), - partialCount :: partialSum :: Nil) + child.dataType match { + case DecimalType.Fixed(_, _) => + // Turn the results to unlimited decimals for the divsion, before going back to fixed + val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited) + val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited) + SplitEvaluation( + Cast(Divide(castedSum, castedCount), dataType), + partialCount :: partialSum :: Nil) + + case _ => + val castedSum = Cast(Sum(partialSum.toAttribute), dataType) + val castedCount = Cast(Sum(partialCount.toAttribute), dataType) + SplitEvaluation( + Divide(castedSum, castedCount), + partialCount :: partialSum :: Nil) + } } override def newInstance() = new AverageFunction(child, this) @@ -306,7 +326,16 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { override def nullable = false - override def dataType = child.dataType + + override def dataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType + } + override def toString = s"SUM($child)" override def asPartial: SplitEvaluation = { @@ -322,9 +351,17 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class SumDistinct(child: Expression) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = child.dataType + + override def dataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType + } + override def toString = s"SUM(DISTINCT $child)" override def newInstance() = new SumDistinctFunction(child, this) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 83e8466ec2aa7..8574cabc43525 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -36,7 +36,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression { case class Sqrt(child: Expression) extends UnaryExpression { type EvaluatedType = Any - + def dataType = DoubleType override def foldable = child.foldable def nullable = child.nullable @@ -55,7 +55,9 @@ abstract class BinaryArithmetic extends BinaryExpression { def nullable = left.nullable || right.nullable override lazy val resolved = - left.resolved && right.resolved && left.dataType == right.dataType + left.resolved && right.resolved && + left.dataType == right.dataType && + !DecimalType.isFixed(left.dataType) def dataType = { if (!resolved) { @@ -104,6 +106,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { def symbol = "/" + override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType] + override def eval(input: Row): Any = dataType match { case _: FractionalType => f2(input, left, right, _.div(_, _)) case _: IntegralType => i2(input, left , right, _.quot(_, _)) @@ -114,6 +118,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { def symbol = "%" + override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType] + override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5a3f013c34579..67f8d411b6bb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import com.google.common.cache.{CacheLoader, CacheBuilder} +import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.language.existentials @@ -485,6 +486,34 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } """.children + case UnscaledValue(child) => + val childEval = expressionEvaluator(child) + + childEval.code ++ + q""" + var $nullTerm = ${childEval.nullTerm} + var $primitiveTerm: Long = if (!$nullTerm) { + ${childEval.primitiveTerm}.toUnscaledLong + } else { + ${defaultPrimitive(LongType)} + } + """.children + + case MakeDecimal(child, precision, scale) => + val childEval = expressionEvaluator(child) + + childEval.code ++ + q""" + var $nullTerm = ${childEval.nullTerm} + var $primitiveTerm: org.apache.spark.sql.catalyst.types.decimal.Decimal = + ${defaultPrimitive(DecimalType())} + + if (!$nullTerm) { + $primitiveTerm = new org.apache.spark.sql.catalyst.types.decimal.Decimal() + $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale) + $nullTerm = $primitiveTerm == null + } + """.children } // If there was no match in the partial function above, we fall back on calling the interpreted @@ -562,7 +591,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case LongType => ru.Literal(Constant(1L)) case ByteType => ru.Literal(Constant(-1.toByte)) case DoubleType => ru.Literal(Constant(-1.toDouble)) - case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed. + case DecimalType() => q"org.apache.spark.sql.catalyst.types.decimal.Decimal(-1)" case IntegerType => ru.Literal(Constant(-1)) case _ => ru.Literal(Constant(null)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala new file mode 100644 index 0000000000000..d1eab2eb4ed56 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.catalyst.types.{DecimalType, LongType, DoubleType, DataType} + +/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ +case class UnscaledValue(child: Expression) extends UnaryExpression { + override type EvaluatedType = Any + + override def dataType: DataType = LongType + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"UnscaledValue($child)" + + override def eval(input: Row): Any = { + val childResult = child.eval(input) + if (childResult == null) { + null + } else { + childResult.asInstanceOf[Decimal].toUnscaledLong + } + } +} + +/** Create a Decimal from an unscaled Long value */ +case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { + override type EvaluatedType = Decimal + + override def dataType: DataType = DecimalType(precision, scale) + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"MakeDecimal($child,$precision,$scale)" + + override def eval(input: Row): Decimal = { + val childResult = child.eval(input) + if (childResult == null) { + null + } else { + new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 9c865254e0be9..ab0701fd9a80b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -43,7 +43,7 @@ abstract class Generator extends Expression { override type EvaluatedType = TraversableOnce[Row] override lazy val dataType = - ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable)))) + ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) override def nullable = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index ba240233cae61..93c19325151bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal object Literal { def apply(v: Any): Literal = v match { @@ -31,7 +32,8 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(s, StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(d, DecimalType) + case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) + case d: Decimal => Literal(d, DecimalType.Unlimited) case t: Timestamp => Literal(t, TimestampType) case d: Date => Literal(d, DateType) case a: Array[Byte] => Literal(a, BinaryType) @@ -62,7 +64,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression { } // TODO: Specialize -case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) +case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) extends LeafExpression { type EvaluatedType = Any diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index fe13a661f6f7a..fc90a54a58259 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.util.Metadata object NamedExpression { private val curId = new java.util.concurrent.atomic.AtomicLong() @@ -43,6 +44,9 @@ abstract class NamedExpression extends Expression { def toAttribute: Attribute + /** Returns the metadata when an expression is a reference to another expression with metadata. */ + def metadata: Metadata = Metadata.empty + protected def typeSuffix = if (resolved) { dataType match { @@ -88,10 +92,16 @@ case class Alias(child: Expression, name: String) override def dataType = child.dataType override def nullable = child.nullable + override def metadata: Metadata = { + child match { + case named: NamedExpression => named.metadata + case _ => Metadata.empty + } + } override def toAttribute = { if (resolved) { - AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers) + AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers) } else { UnresolvedAttribute(name) } @@ -108,18 +118,23 @@ case class Alias(child: Expression, name: String) * @param name The name of this attribute, should only be used during analysis or for debugging. * @param dataType The [[DataType]] of this attribute. * @param nullable True if null is a valid value for this attribute. + * @param metadata The metadata of this attribute. * @param exprId A globally unique id used to check if different AttributeReferences refer to the * same attribute. * @param qualifiers a list of strings that can be used to referred to this attribute in a fully * qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. */ -case class AttributeReference(name: String, dataType: DataType, nullable: Boolean = true) - (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) - extends Attribute with trees.LeafNode[Expression] { +case class AttributeReference( + name: String, + dataType: DataType, + nullable: Boolean = true, + override val metadata: Metadata = Metadata.empty)( + val exprId: ExprId = NamedExpression.newExprId, + val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { override def equals(other: Any) = other match { - case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType + case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType case _ => false } @@ -128,10 +143,12 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea var h = 17 h = h * 37 + exprId.hashCode() h = h * 37 + dataType.hashCode() + h = h * 37 + metadata.hashCode() h } - override def newInstance() = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers) + override def newInstance() = + AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers) /** * Returns a copy of this [[AttributeReference]] with changed nullability. @@ -140,7 +157,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea if (nullable == newNullability) { this } else { - AttributeReference(name, dataType, newNullability)(exprId, qualifiers) + AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers) } } @@ -159,7 +176,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea if (newQualifiers.toSet == qualifiers.toSet) { this } else { - AttributeReference(name, dataType, nullable)(exprId, newQualifiers) + AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9ce7c78195830..a4aa322fc52d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal abstract class Optimizer extends RuleExecutor[LogicalPlan] @@ -43,6 +44,8 @@ object DefaultOptimizer extends Optimizer { SimplifyCasts, SimplifyCaseConversionExpressions, OptimizeIn) :: + Batch("Decimal Optimizations", FixedPoint(100), + DecimalAggregates) :: Batch("Filter Pushdown", FixedPoint(100), UnionPushdown, CombineFilters, @@ -390,9 +393,9 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { * evaluated using only the attributes of the left or right side of a join. Other * [[Filter]] conditions are moved into the `condition` of the [[Join]]. * - * And also Pushes down the join filter, where the `condition` can be evaluated using only the - * attributes of the left or right side of sub query when applicable. - * + * And also Pushes down the join filter, where the `condition` can be evaluated using only the + * attributes of the left or right side of sub query when applicable. + * * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details */ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { @@ -404,7 +407,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { val (leftEvaluateCondition, rest) = condition.partition(_.references subsetOf left.outputSet) - val (rightEvaluateCondition, commonCondition) = + val (rightEvaluateCondition, commonCondition) = rest.partition(_.references subsetOf right.outputSet) (leftEvaluateCondition, rightEvaluateCondition, commonCondition) @@ -413,7 +416,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // push the where condition down into join filter case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => - val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = + val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = split(splitConjunctivePredicates(filterCondition), left, right) joinType match { @@ -451,7 +454,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { // push down the join filter into sub query scanning if applicable case f @ Join(left, right, joinType, joinCondition) => - val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = + val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) joinType match { @@ -519,3 +522,26 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { } } } + +/** + * Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values. + * + * This uses the same rules for increasing the precision and scale of the output as + * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.DecimalPrecision]]. + */ +object DecimalAggregates extends Rule[LogicalPlan] { + import Decimal.MAX_LONG_DIGITS + + /** Maximum number of decimal digits representable precisely in a Double */ + val MAX_DOUBLE_DIGITS = 15 + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale) + + case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + Cast( + Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)), + DecimalType(prec + 4, scale + 4)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index bdd07bbeb2230..a38079ced34b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +/** + * Catalyst is a library for manipulating relational query plans. All classes in catalyst are + * considered an internal API to Spark SQL and are subject to change between minor releases. + */ package object catalyst { /** * A JVM-global lock that should be used to prevent thread safety issues when using things in diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 5839c9f7c43ef..51b5699affed5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -21,6 +21,15 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode +/** + * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can + * be used for execution. If this strategy does not apply to the give logical operation then an + * empty list should be returned. + */ +abstract class GenericStrategy[PhysicalPlan <: TreeNode[PhysicalPlan]] extends Logging { + def apply(plan: LogicalPlan): Seq[PhysicalPlan] +} + /** * Abstract class for transforming [[plans.logical.LogicalPlan LogicalPlan]]s into physical plans. * Child classes are responsible for specifying a list of [[Strategy]] objects that each of which @@ -35,16 +44,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNode */ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { /** A list of execution strategies that can be used by the planner */ - def strategies: Seq[Strategy] - - /** - * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can - * be used for execution. If this strategy does not apply to the give logical operation then an - * empty list should be returned. - */ - abstract protected class Strategy extends Logging { - def apply(plan: LogicalPlan): Seq[PhysicalPlan] - } + def strategies: Seq[GenericStrategy[PhysicalPlan]] /** * Returns a placeholder for a physical plan that executes `plan`. This placeholder will be diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index b8ba2ee428a20..1d513d7789763 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.types.StringType /** @@ -41,6 +41,15 @@ case class NativeCommand(cmd: String) extends Command { /** * Commands of the form "SET [key [= value] ]". */ +case class DFSCommand(kv: Option[(String, Option[String])]) extends Command { + override def output = Seq( + AttributeReference("DFS output", StringType, nullable = false)()) +} + +/** + * + * Commands of the form "SET [key [= value] ]". + */ case class SetCommand(kv: Option[(String, Option[String])]) extends Command { override def output = Seq( AttributeReference("", StringType, nullable = false)()) @@ -81,14 +90,3 @@ case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false)(), AttributeReference("comment", StringType, nullable = false)()) } - -/** - * Returned for the "! shellCommand" command - */ -case class ShellCommand(cmd: String) extends Command - - -/** - * Returned for the "SOURCE file" command - */ -case class SourceCommand(filePath: String) extends Command diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index b9cf37d53ffd2..5dd19dd12d8dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,21 +19,23 @@ package org.apache.spark.sql.catalyst.types import java.sql.{Date, Timestamp} -import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral} +import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} import scala.util.parsing.combinator.RegexParsers -import org.json4s.JsonAST.JValue import org.json4s._ +import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row} +import org.apache.spark.sql.catalyst.types.decimal._ +import org.apache.spark.sql.catalyst.util.Metadata import org.apache.spark.util.Utils - object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) @@ -66,17 +68,25 @@ object DataType { ("fields", JArray(fields)), ("type", JString("struct"))) => StructType(fields.map(parseStructField)) + + case JSortedObject( + ("class", JString(udtClass)), + ("pyClass", _), + ("sqlType", _), + ("type", JString("udt"))) => + Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] } private def parseStructField(json: JValue): StructField = json match { case JSortedObject( + ("metadata", metadata: JObject), ("name", JString(name)), ("nullable", JBool(nullable)), ("type", dataType: JValue)) => - StructField(name, parseDataType(dataType), nullable) + StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata)) } - @deprecated("Use DataType.fromJson instead") + @deprecated("Use DataType.fromJson instead", "1.2.0") def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) private object CaseClassStringParser extends RegexParsers { @@ -90,11 +100,17 @@ object DataType { | "LongType" ^^^ LongType | "BinaryType" ^^^ BinaryType | "BooleanType" ^^^ BooleanType - | "DecimalType" ^^^ DecimalType | "DateType" ^^^ DateType + | "DecimalType()" ^^^ DecimalType.Unlimited + | fixedDecimalType | "TimestampType" ^^^ TimestampType ) + protected lazy val fixedDecimalType: Parser[DataType] = + ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } + protected lazy val arrayType: Parser[DataType] = "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) @@ -199,10 +215,18 @@ trait PrimitiveType extends DataType { } object PrimitiveType { - private[sql] val all = Seq(DecimalType, DateType, TimestampType, BinaryType) ++ - NativeType.all - - private[sql] val nameToType = all.map(t => t.typeName -> t).toMap + private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all + private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap + + /** Given the string representation of a type, return its DataType */ + private[sql] def nameToType(name: String): DataType = { + val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r + name match { + case "decimal" => DecimalType.Unlimited + case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) + case other => nonDecimalNameToType(other) + } + } } abstract class NativeType extends DataType { @@ -326,18 +350,64 @@ object FractionalType { case _ => false } } + abstract class FractionalType extends NumericType { private[sql] val fractional: Fractional[JvmType] private[sql] val asIntegral: Integral[JvmType] } -case object DecimalType extends FractionalType { - private[sql] type JvmType = BigDecimal +/** Precision parameters for a Decimal */ +case class PrecisionInfo(precision: Int, scale: Int) + +/** A Decimal that might have fixed precision and scale, or unlimited values for these */ +case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { + private[sql] type JvmType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val numeric = implicitly[Numeric[BigDecimal]] - private[sql] val fractional = implicitly[Fractional[BigDecimal]] - private[sql] val ordering = implicitly[Ordering[JvmType]] - private[sql] val asIntegral = BigDecimalAsIfIntegral + private[sql] val numeric = Decimal.DecimalIsFractional + private[sql] val fractional = Decimal.DecimalIsFractional + private[sql] val ordering = Decimal.DecimalIsFractional + private[sql] val asIntegral = Decimal.DecimalAsIfIntegral + + override def typeName: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" + case None => "decimal" + } + + override def toString: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" + case None => "DecimalType()" + } +} + +/** Extra factory methods and pattern matchers for Decimals */ +object DecimalType { + val Unlimited: DecimalType = DecimalType(None) + + object Fixed { + def unapply(t: DecimalType): Option[(Int, Int)] = + t.precisionInfo.map(p => (p.precision, p.scale)) + } + + object Expression { + def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { + case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) + case _ => None + } + } + + def apply(): DecimalType = Unlimited + + def apply(precision: Int, scale: Int): DecimalType = + DecimalType(Some(PrecisionInfo(precision, scale))) + + def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] + + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] + + def isFixed(dataType: DataType): Boolean = dataType match { + case DecimalType.Fixed(_, _) => true + case _ => false + } } case object DoubleType extends FractionalType { @@ -388,24 +458,34 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT * @param name The name of this field. * @param dataType The data type of this field. * @param nullable Indicates if values of this field can be `null` values. + * @param metadata The metadata of this field. The metadata should be preserved during + * transformation if the content of the column is not modified, e.g, in selection. */ -case class StructField(name: String, dataType: DataType, nullable: Boolean) { +case class StructField( + name: String, + dataType: DataType, + nullable: Boolean = true, + metadata: Metadata = Metadata.empty) { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", builder) } + // override the default toString to be compatible with legacy parquet files. + override def toString: String = s"StructField($name,$dataType,$nullable)" + private[sql] def jsonValue: JValue = { ("name" -> name) ~ ("type" -> dataType.jsonValue) ~ - ("nullable" -> nullable) + ("nullable" -> nullable) ~ + ("metadata" -> metadata.jsonValue) } } object StructType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = - StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) } case class StructType(fields: Seq[StructField]) extends DataType { @@ -439,7 +519,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { } protected[sql] def toAttributes = - fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) def treeString: String = { val builder = new StringBuilder @@ -494,3 +574,50 @@ case class MapType( ("valueType" -> valueType.jsonValue) ~ ("valueContainsNull" -> valueContainsNull) } + +/** + * ::DeveloperApi:: + * The data type for User Defined Types (UDTs). + * + * This interface allows a user to make their own classes more interoperable with SparkSQL; + * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create + * a SchemaRDD which has class X in the schema. + * + * For SparkSQL to recognize UDTs, the UDT must be annotated with + * [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]]. + * + * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. + * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. + */ +@DeveloperApi +abstract class UserDefinedType[UserType] extends DataType with Serializable { + + /** Underlying storage type for this UDT */ + def sqlType: DataType + + /** Paired Python UDT class, if exists. */ + def pyUDT: String = null + + /** + * Convert the user type to a SQL datum + * + * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, + * where we need to convert Any to UserType. + */ + def serialize(obj: Any): Any + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType + + override private[sql] def jsonValue: JValue = { + ("type" -> "udt") ~ + ("class" -> this.getClass.getName) ~ + ("pyClass" -> pyUDT) ~ + ("sqlType" -> sqlType.jsonValue) + } + + /** + * Class object for the UserType + */ + def userClass: java.lang.Class[UserType] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala new file mode 100644 index 0000000000000..708362acf32dc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.types.decimal + +import org.apache.spark.annotation.DeveloperApi + +/** + * A mutable implementation of BigDecimal that can hold a Long if values are small enough. + * + * The semantics of the fields are as follows: + * - _precision and _scale represent the SQL precision and scale we are looking for + * - If decimalVal is set, it represents the whole decimal value + * - Otherwise, the decimal value is longVal / (10 ** _scale) + */ +final class Decimal extends Ordered[Decimal] with Serializable { + import Decimal.{MAX_LONG_DIGITS, POW_10, ROUNDING_MODE, BIG_DEC_ZERO} + + private var decimalVal: BigDecimal = null + private var longVal: Long = 0L + private var _precision: Int = 1 + private var _scale: Int = 0 + + def precision: Int = _precision + def scale: Int = _scale + + /** + * Set this Decimal to the given Long. Will have precision 20 and scale 0. + */ + def set(longVal: Long): Decimal = { + if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) { + // We can't represent this compactly as a long without risking overflow + this.decimalVal = BigDecimal(longVal) + this.longVal = 0L + } else { + this.decimalVal = null + this.longVal = longVal + } + this._precision = 20 + this._scale = 0 + this + } + + /** + * Set this Decimal to the given Int. Will have precision 10 and scale 0. + */ + def set(intVal: Int): Decimal = { + this.decimalVal = null + this.longVal = intVal + this._precision = 10 + this._scale = 0 + this + } + + /** + * Set this Decimal to the given unscaled Long, with a given precision and scale. + */ + def set(unscaled: Long, precision: Int, scale: Int): Decimal = { + if (setOrNull(unscaled, precision, scale) == null) { + throw new IllegalArgumentException("Unscaled value too large for precision") + } + this + } + + /** + * Set this Decimal to the given unscaled Long, with a given precision and scale, + * and return it, or return null if it cannot be set due to overflow. + */ + def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = { + if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) { + // We can't represent this compactly as a long without risking overflow + if (precision < 19) { + return null // Requested precision is too low to represent this value + } + this.decimalVal = BigDecimal(longVal) + this.longVal = 0L + } else { + val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) + if (unscaled <= -p || unscaled >= p) { + return null // Requested precision is too low to represent this value + } + this.decimalVal = null + this.longVal = unscaled + } + this._precision = precision + this._scale = scale + this + } + + /** + * Set this Decimal to the given BigDecimal value, with a given precision and scale. + */ + def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { + this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) + require(decimalVal.precision <= precision, "Overflowed precision") + this.longVal = 0L + this._precision = precision + this._scale = scale + this + } + + /** + * Set this Decimal to the given BigDecimal value, inheriting its precision and scale. + */ + def set(decimal: BigDecimal): Decimal = { + this.decimalVal = decimal + this.longVal = 0L + this._precision = decimal.precision + this._scale = decimal.scale + this + } + + /** + * Set this Decimal to the given Decimal value. + */ + def set(decimal: Decimal): Decimal = { + this.decimalVal = decimal.decimalVal + this.longVal = decimal.longVal + this._precision = decimal._precision + this._scale = decimal._scale + this + } + + def toBigDecimal: BigDecimal = { + if (decimalVal.ne(null)) { + decimalVal + } else { + BigDecimal(longVal, _scale) + } + } + + def toUnscaledLong: Long = { + if (decimalVal.ne(null)) { + decimalVal.underlying().unscaledValue().longValue() + } else { + longVal + } + } + + override def toString: String = toBigDecimal.toString() + + @DeveloperApi + def toDebugString: String = { + if (decimalVal.ne(null)) { + s"Decimal(expanded,$decimalVal,$precision,$scale})" + } else { + s"Decimal(compact,$longVal,$precision,$scale})" + } + } + + def toDouble: Double = toBigDecimal.doubleValue() + + def toFloat: Float = toBigDecimal.floatValue() + + def toLong: Long = { + if (decimalVal.eq(null)) { + longVal / POW_10(_scale) + } else { + decimalVal.longValue() + } + } + + def toInt: Int = toLong.toInt + + def toShort: Short = toLong.toShort + + def toByte: Byte = toLong.toByte + + /** + * Update precision and scale while keeping our value the same, and return true if successful. + * + * @return true if successful, false if overflow would occur + */ + def changePrecision(precision: Int, scale: Int): Boolean = { + // First, update our longVal if we can, or transfer over to using a BigDecimal + if (decimalVal.eq(null)) { + if (scale < _scale) { + // Easier case: we just need to divide our scale down + val diff = _scale - scale + val droppedDigits = longVal % POW_10(diff) + longVal /= POW_10(diff) + if (math.abs(droppedDigits) * 2 >= POW_10(diff)) { + longVal += (if (longVal < 0) -1L else 1L) + } + } else if (scale > _scale) { + // We might be able to multiply longVal by a power of 10 and not overflow, but if not, + // switch to using a BigDecimal + val diff = scale - _scale + val p = POW_10(math.max(MAX_LONG_DIGITS - diff, 0)) + if (diff <= MAX_LONG_DIGITS && longVal > -p && longVal < p) { + // Multiplying longVal by POW_10(diff) will still keep it below MAX_LONG_DIGITS + longVal *= POW_10(diff) + } else { + // Give up on using Longs; switch to BigDecimal, which we'll modify below + decimalVal = BigDecimal(longVal, _scale) + } + } + // In both cases, we will check whether our precision is okay below + } + + if (decimalVal.ne(null)) { + // We get here if either we started with a BigDecimal, or we switched to one because we would + // have overflowed our Long; in either case we must rescale decimalVal to the new scale. + val newVal = decimalVal.setScale(scale, ROUNDING_MODE) + if (newVal.precision > precision) { + return false + } + decimalVal = newVal + } else { + // We're still using Longs, but we should check whether we match the new precision + val p = POW_10(math.min(_precision, MAX_LONG_DIGITS)) + if (longVal <= -p || longVal >= p) { + // Note that we shouldn't have been able to fix this by switching to BigDecimal + return false + } + } + + _precision = precision + _scale = scale + true + } + + override def clone(): Decimal = new Decimal().set(this) + + override def compare(other: Decimal): Int = { + if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) { + if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1 + } else { + toBigDecimal.compare(other.toBigDecimal) + } + } + + override def equals(other: Any) = other match { + case d: Decimal => + compare(d) == 0 + case _ => + false + } + + override def hashCode(): Int = toBigDecimal.hashCode() + + def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0 + + def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal) + + def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal) + + def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) + + def / (that: Decimal): Decimal = + if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + + def % (that: Decimal): Decimal = + if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) + + def remainder(that: Decimal): Decimal = this % that + + def unary_- : Decimal = { + if (decimalVal.ne(null)) { + Decimal(-decimalVal) + } else { + Decimal(-longVal, precision, scale) + } + } +} + +object Decimal { + private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP + + /** Maximum number of decimal digits a Long can represent */ + val MAX_LONG_DIGITS = 18 + + private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) + + private val BIG_DEC_ZERO = BigDecimal(0) + + def apply(value: Double): Decimal = new Decimal().set(value) + + def apply(value: Long): Decimal = new Decimal().set(value) + + def apply(value: Int): Decimal = new Decimal().set(value) + + def apply(value: BigDecimal): Decimal = new Decimal().set(value) + + def apply(value: BigDecimal, precision: Int, scale: Int): Decimal = + new Decimal().set(value, precision, scale) + + def apply(unscaled: Long, precision: Int, scale: Int): Decimal = + new Decimal().set(unscaled, precision, scale) + + def apply(value: String): Decimal = new Decimal().set(BigDecimal(value)) + + // Evidence parameters for Decimal considered either as Fractional or Integral. We provide two + // parameters inheriting from a common trait since both traits define mkNumericOps. + // See scala.math's Numeric.scala for examples for Scala's built-in types. + + /** Common methods for Decimal evidence parameters */ + trait DecimalIsConflicted extends Numeric[Decimal] { + override def plus(x: Decimal, y: Decimal): Decimal = x + y + override def times(x: Decimal, y: Decimal): Decimal = x * y + override def minus(x: Decimal, y: Decimal): Decimal = x - y + override def negate(x: Decimal): Decimal = -x + override def toDouble(x: Decimal): Double = x.toDouble + override def toFloat(x: Decimal): Float = x.toFloat + override def toInt(x: Decimal): Int = x.toInt + override def toLong(x: Decimal): Long = x.toLong + override def fromInt(x: Int): Decimal = new Decimal().set(x) + override def compare(x: Decimal, y: Decimal): Int = x.compare(y) + } + + /** A [[scala.math.Fractional]] evidence parameter for Decimals. */ + object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { + override def div(x: Decimal, y: Decimal): Decimal = x / y + } + + /** A [[scala.math.Integral]] evidence parameter for Decimals. */ + object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { + override def quot(x: Decimal, y: Decimal): Decimal = x / y + override def rem(x: Decimal, y: Decimal): Decimal = x % y + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala new file mode 100644 index 0000000000000..2f2082fa3c863 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection.mutable + +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +/** + * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, + * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and + * Array[Metadata]. JSON is used for serialization. + * + * The default constructor is private. User should use either [[MetadataBuilder]] or + * [[Metadata$#fromJson]] to create Metadata instances. + * + * @param map an immutable map that stores the data + */ +sealed class Metadata private[util] (private[util] val map: Map[String, Any]) extends Serializable { + + /** Gets a Long. */ + def getLong(key: String): Long = get(key) + + /** Gets a Double. */ + def getDouble(key: String): Double = get(key) + + /** Gets a Boolean. */ + def getBoolean(key: String): Boolean = get(key) + + /** Gets a String. */ + def getString(key: String): String = get(key) + + /** Gets a Metadata. */ + def getMetadata(key: String): Metadata = get(key) + + /** Gets a Long array. */ + def getLongArray(key: String): Array[Long] = get(key) + + /** Gets a Double array. */ + def getDoubleArray(key: String): Array[Double] = get(key) + + /** Gets a Boolean array. */ + def getBooleanArray(key: String): Array[Boolean] = get(key) + + /** Gets a String array. */ + def getStringArray(key: String): Array[String] = get(key) + + /** Gets a Metadata array. */ + def getMetadataArray(key: String): Array[Metadata] = get(key) + + /** Converts to its JSON representation. */ + def json: String = compact(render(jsonValue)) + + override def toString: String = json + + override def equals(obj: Any): Boolean = { + obj match { + case that: Metadata => + if (map.keySet == that.map.keySet) { + map.keys.forall { k => + (map(k), that.map(k)) match { + case (v0: Array[_], v1: Array[_]) => + v0.view == v1.view + case (v0, v1) => + v0 == v1 + } + } + } else { + false + } + case other => + false + } + } + + override def hashCode: Int = Metadata.hash(this) + + private def get[T](key: String): T = { + map(key).asInstanceOf[T] + } + + private[sql] def jsonValue: JValue = Metadata.toJsonValue(this) +} + +object Metadata { + + /** Returns an empty Metadata. */ + def empty: Metadata = new Metadata(Map.empty) + + /** Creates a Metadata instance from JSON. */ + def fromJson(json: String): Metadata = { + fromJObject(parse(json).asInstanceOf[JObject]) + } + + /** Creates a Metadata instance from JSON AST. */ + private[sql] def fromJObject(jObj: JObject): Metadata = { + val builder = new MetadataBuilder + jObj.obj.foreach { + case (key, JInt(value)) => + builder.putLong(key, value.toLong) + case (key, JDouble(value)) => + builder.putDouble(key, value) + case (key, JBool(value)) => + builder.putBoolean(key, value) + case (key, JString(value)) => + builder.putString(key, value) + case (key, o: JObject) => + builder.putMetadata(key, fromJObject(o)) + case (key, JArray(value)) => + if (value.isEmpty) { + // If it is an empty array, we cannot infer its element type. We put an empty Array[Long]. + builder.putLongArray(key, Array.empty) + } else { + value.head match { + case _: JInt => + builder.putLongArray(key, value.asInstanceOf[List[JInt]].map(_.num.toLong).toArray) + case _: JDouble => + builder.putDoubleArray(key, value.asInstanceOf[List[JDouble]].map(_.num).toArray) + case _: JBool => + builder.putBooleanArray(key, value.asInstanceOf[List[JBool]].map(_.value).toArray) + case _: JString => + builder.putStringArray(key, value.asInstanceOf[List[JString]].map(_.s).toArray) + case _: JObject => + builder.putMetadataArray( + key, value.asInstanceOf[List[JObject]].map(fromJObject).toArray) + case other => + throw new RuntimeException(s"Do not support array of type ${other.getClass}.") + } + } + case other => + throw new RuntimeException(s"Do not support type ${other.getClass}.") + } + builder.build() + } + + /** Converts to JSON AST. */ + private def toJsonValue(obj: Any): JValue = { + obj match { + case map: Map[_, _] => + val fields = map.toList.map { case (k: String, v) => (k, toJsonValue(v)) } + JObject(fields) + case arr: Array[_] => + val values = arr.toList.map(toJsonValue) + JArray(values) + case x: Long => + JInt(x) + case x: Double => + JDouble(x) + case x: Boolean => + JBool(x) + case x: String => + JString(x) + case x: Metadata => + toJsonValue(x.map) + case other => + throw new RuntimeException(s"Do not support type ${other.getClass}.") + } + } + + /** Computes the hash code for the types we support. */ + private def hash(obj: Any): Int = { + obj match { + case map: Map[_, _] => + map.mapValues(hash).## + case arr: Array[_] => + // Seq.empty[T] has the same hashCode regardless of T. + arr.toSeq.map(hash).## + case x: Long => + x.## + case x: Double => + x.## + case x: Boolean => + x.## + case x: String => + x.## + case x: Metadata => + hash(x.map) + case other => + throw new RuntimeException(s"Do not support type ${other.getClass}.") + } + } +} + +/** + * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. + */ +class MetadataBuilder { + + private val map: mutable.Map[String, Any] = mutable.Map.empty + + /** Returns the immutable version of this map. Used for java interop. */ + protected def getMap = map.toMap + + /** Include the content of an existing [[Metadata]] instance. */ + def withMetadata(metadata: Metadata): this.type = { + map ++= metadata.map + this + } + + /** Puts a Long. */ + def putLong(key: String, value: Long): this.type = put(key, value) + + /** Puts a Double. */ + def putDouble(key: String, value: Double): this.type = put(key, value) + + /** Puts a Boolean. */ + def putBoolean(key: String, value: Boolean): this.type = put(key, value) + + /** Puts a String. */ + def putString(key: String, value: String): this.type = put(key, value) + + /** Puts a [[Metadata]]. */ + def putMetadata(key: String, value: Metadata): this.type = put(key, value) + + /** Puts a Long array. */ + def putLongArray(key: String, value: Array[Long]): this.type = put(key, value) + + /** Puts a Double array. */ + def putDoubleArray(key: String, value: Array[Double]): this.type = put(key, value) + + /** Puts a Boolean array. */ + def putBooleanArray(key: String, value: Array[Boolean]): this.type = put(key, value) + + /** Puts a String array. */ + def putStringArray(key: String, value: Array[String]): this.type = put(key, value) + + /** Puts a [[Metadata]] array. */ + def putMetadataArray(key: String, value: Array[Metadata]): this.type = put(key, value) + + /** Builds the [[Metadata]] instance. */ + def build(): Metadata = { + new Metadata(map.toMap) + } + + private def put(key: String, value: Any): this.type = { + map.put(key, value) + this + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 430f0664b7d58..ddc3d44869c98 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.types._ case class PrimitiveData( @@ -96,7 +97,7 @@ class ScalaReflectionSuite extends FunSuite { StructField("byteField", ByteType, nullable = true), StructField("booleanField", BooleanType, nullable = true), StructField("stringField", StringType, nullable = true), - StructField("decimalField", DecimalType, nullable = true), + StructField("decimalField", DecimalType.Unlimited, nullable = true), StructField("dateField", DateType, nullable = true), StructField("timestampField", TimestampType, nullable = true), StructField("binaryField", BinaryType, nullable = true))), @@ -199,7 +200,7 @@ class ScalaReflectionSuite extends FunSuite { assert(DoubleType === typeOfObject(1.7976931348623157E308)) // DecimalType - assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318"))) + assert(DecimalType.Unlimited === typeOfObject(BigDecimal("1.7976931348623157E318"))) // DateType assert(DateType === typeOfObject(Date.valueOf("2014-07-25"))) @@ -211,19 +212,19 @@ class ScalaReflectionSuite extends FunSuite { assert(NullType === typeOfObject(null)) def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType - case value: java.math.BigDecimal => DecimalType + case value: java.math.BigInteger => DecimalType.Unlimited + case value: java.math.BigDecimal => DecimalType.Unlimited case _ => StringType } - assert(DecimalType === typeOfObject1( + assert(DecimalType.Unlimited === typeOfObject1( new BigInteger("92233720368547758070"))) - assert(DecimalType === typeOfObject1( + assert(DecimalType.Unlimited === typeOfObject1( new java.math.BigDecimal("1.7976931348623157E318"))) assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType + case value: java.math.BigInteger => DecimalType.Unlimited } intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) @@ -239,13 +240,17 @@ class ScalaReflectionSuite extends FunSuite { test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) - assert(convertToCatalyst(data) === convertedData) + val dataType = schemaFor[PrimitiveData].dataType + assert(convertToCatalyst(data, dataType) === convertedData) } test("convert Option[Product] to catalyst") { val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true) - val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true), Some(primitiveData)) - val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, convertToCatalyst(primitiveData)) - assert(convertToCatalyst(data) === convertedData) + val data = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(primitiveData)) + val dataType = schemaFor[OptionalData].dataType + val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true, + Row(1, 1, 1, 1, 1, 1, true)) + assert(convertToCatalyst(data, dataType) === convertedData) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 7b45738c4fc95..33a3cba3d4c0e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -38,7 +38,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType)(), + AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) before { @@ -119,7 +119,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType)(), + AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) val expr0 = 'a / 2 @@ -137,7 +137,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - assert(pl(3).dataType == DecimalType) + assert(pl(3).dataType == DecimalType.Unlimited) assert(pl(4).dataType == DoubleType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala new file mode 100644 index 0000000000000..d5b7d2789a103 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} +import org.apache.spark.sql.catalyst.types._ +import org.scalatest.{BeforeAndAfter, FunSuite} + +class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { + val catalog = new SimpleCatalog(false) + val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = false) + + val relation = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("d1", DecimalType(2, 1))(), + AttributeReference("d2", DecimalType(5, 2))(), + AttributeReference("u", DecimalType.Unlimited)(), + AttributeReference("f", FloatType)() + ) + + val i: Expression = UnresolvedAttribute("i") + val d1: Expression = UnresolvedAttribute("d1") + val d2: Expression = UnresolvedAttribute("d2") + val u: Expression = UnresolvedAttribute("u") + val f: Expression = UnresolvedAttribute("f") + + before { + catalog.registerTable(None, "table", relation) + } + + private def checkType(expression: Expression, expectedType: DataType): Unit = { + val plan = Project(Seq(Alias(expression, "c")()), relation) + assert(analyzer(plan).schema.fields(0).dataType === expectedType) + } + + test("basic operations") { + checkType(Add(d1, d2), DecimalType(6, 2)) + checkType(Subtract(d1, d2), DecimalType(6, 2)) + checkType(Multiply(d1, d2), DecimalType(8, 3)) + checkType(Divide(d1, d2), DecimalType(10, 7)) + checkType(Divide(d2, d1), DecimalType(10, 6)) + checkType(Remainder(d1, d2), DecimalType(3, 2)) + checkType(Remainder(d2, d1), DecimalType(3, 2)) + checkType(Sum(d1), DecimalType(12, 1)) + checkType(Average(d1), DecimalType(6, 5)) + + checkType(Add(Add(d1, d2), d1), DecimalType(7, 2)) + checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2)) + checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2)) + } + + test("bringing in primitive types") { + checkType(Add(d1, i), DecimalType(12, 1)) + checkType(Add(d1, f), DoubleType) + checkType(Add(i, d1), DecimalType(12, 1)) + checkType(Add(f, d1), DoubleType) + checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1)) + checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1)) + checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1)) + checkType(Add(d1, Cast(i, DoubleType)), DoubleType) + } + + test("unlimited decimals make everything else cast up") { + for (expr <- Seq(d1, d2, i, f, u)) { + checkType(Add(expr, u), DecimalType.Unlimited) + checkType(Subtract(expr, u), DecimalType.Unlimited) + checkType(Multiply(expr, u), DecimalType.Unlimited) + checkType(Divide(expr, u), DecimalType.Unlimited) + checkType(Remainder(expr, u), DecimalType.Unlimited) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index baeb9b0cf5964..dfa2d958c0faf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -68,6 +68,21 @@ class HiveTypeCoercionSuite extends FunSuite { widenTest(LongType, FloatType, Some(FloatType)) widenTest(LongType, DoubleType, Some(DoubleType)) + // Casting up to unlimited-precision decimal + widenTest(IntegerType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) + widenTest(DoubleType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) + widenTest(DecimalType(3, 2), DecimalType.Unlimited, Some(DecimalType.Unlimited)) + widenTest(DecimalType.Unlimited, IntegerType, Some(DecimalType.Unlimited)) + widenTest(DecimalType.Unlimited, DoubleType, Some(DecimalType.Unlimited)) + widenTest(DecimalType.Unlimited, DecimalType(3, 2), Some(DecimalType.Unlimited)) + + // No up-casting for fixed-precision decimal (this is handled by arithmetic rules) + widenTest(DecimalType(2, 1), DecimalType(3, 2), None) + widenTest(DecimalType(2, 1), DoubleType, None) + widenTest(DecimalType(2, 1), IntegerType, None) + widenTest(DoubleType, DecimalType(2, 1), None) + widenTest(IntegerType, DecimalType(2, 1), None) + // StringType widenTest(NullType, StringType, Some(StringType)) widenTest(StringType, StringType, Some(StringType)) @@ -92,7 +107,7 @@ class HiveTypeCoercionSuite extends FunSuite { def ruleTest(initial: Expression, transformed: Expression) { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) == - Project(Seq(Alias(transformed, "a")()), testRelation)) + Project(Seq(Alias(transformed, "a")()), testRelation)) } // Remove superflous boolean -> boolean casts. ruleTest(Cast(Literal(true), BooleanType), Literal(true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 5657bc555edf9..918996f11da2c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.scalatest.FunSuite import org.scalatest.Matchers._ import org.scalactic.TripleEqualsSupport.Spread @@ -138,7 +139,7 @@ class ExpressionEvaluationSuite extends FunSuite { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - actual.asInstanceOf[Double] shouldBe expected + actual.asInstanceOf[Double] shouldBe expected } test("IN") { @@ -165,7 +166,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(InSet(three, nS, three +: nullS), false) checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true) } - + test("MaxOf") { checkEvaluation(MaxOf(1, 2), 2) checkEvaluation(MaxOf(2, 1), 2) @@ -265,9 +266,9 @@ class ExpressionEvaluationSuite extends FunSuite { val ts = Timestamp.valueOf(nts) checkEvaluation("abdef" cast StringType, "abdef") - checkEvaluation("abdef" cast DecimalType, null) + checkEvaluation("abdef" cast DecimalType.Unlimited, null) checkEvaluation("abdef" cast TimestampType, null) - checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65)) + checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65)) checkEvaluation(Literal(1) cast LongType, 1) checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) @@ -289,12 +290,12 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Cast(Cast(Cast( Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 0) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 0) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) @@ -302,7 +303,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("23" cast DoubleType, 23d) checkEvaluation("23" cast IntegerType, 23) checkEvaluation("23" cast FloatType, 23f) - checkEvaluation("23" cast DecimalType, 23: BigDecimal) + checkEvaluation("23" cast DecimalType.Unlimited, Decimal(23)) checkEvaluation("23" cast ByteType, 23.toByte) checkEvaluation("23" cast ShortType, 23.toShort) checkEvaluation("2012-12-11" cast DoubleType, null) @@ -311,7 +312,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d) checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f) - checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal) + checkEvaluation(Literal(Decimal(23)) + Cast(true, DecimalType.Unlimited), Decimal(24)) checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte) checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort) @@ -325,7 +326,8 @@ class ExpressionEvaluationSuite extends FunSuite { assert(("abcdef" cast IntegerType).nullable === true) assert(("abcdef" cast ShortType).nullable === true) assert(("abcdef" cast ByteType).nullable === true) - assert(("abcdef" cast DecimalType).nullable === true) + assert(("abcdef" cast DecimalType.Unlimited).nullable === true) + assert(("abcdef" cast DecimalType(4, 2)).nullable === true) assert(("abcdef" cast DoubleType).nullable === true) assert(("abcdef" cast FloatType).nullable === true) @@ -338,6 +340,64 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(d1) < Literal(d2), true) } + test("casting to fixed-precision decimals") { + // Overflow and rounding for casting to fixed-precision decimals: + // - Values should round with HALF_UP mode by default when you lower scale + // - Values that would overflow the target precision should turn into null + // - Because of this, casts to fixed-precision decimals should be nullable + + assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable === false) + + assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true) + + checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123)) + checkEvaluation(Cast(Literal(123), DecimalType(3, 0)), Decimal(123)) + checkEvaluation(Cast(Literal(123), DecimalType(3, 1)), null) + checkEvaluation(Cast(Literal(123), DecimalType(2, 0)), null) + + checkEvaluation(Cast(Literal(10.03), DecimalType.Unlimited), Decimal(10.03)) + checkEvaluation(Cast(Literal(10.03), DecimalType(4, 2)), Decimal(10.03)) + checkEvaluation(Cast(Literal(10.03), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(10.03), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(10.03), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(10.03), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(10.03), DecimalType(3, 2)), null) + checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 2)), null) + + checkEvaluation(Cast(Literal(10.05), DecimalType.Unlimited), Decimal(10.05)) + checkEvaluation(Cast(Literal(10.05), DecimalType(4, 2)), Decimal(10.05)) + checkEvaluation(Cast(Literal(10.05), DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(Cast(Literal(10.05), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(10.05), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(10.05), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(10.05), DecimalType(3, 2)), null) + checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 2)), null) + + checkEvaluation(Cast(Literal(9.95), DecimalType(3, 2)), Decimal(9.95)) + checkEvaluation(Cast(Literal(9.95), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(9.95), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(9.95), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(9.95), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(1, 0)), null) + + checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 2)), Decimal(-9.95)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 0)), Decimal(-10)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(-9.95), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(1, 0)), null) + } + test("timestamp") { val ts1 = new Timestamp(12) val ts2 = new Timestamp(123) @@ -352,6 +412,8 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(d, LongType), null) checkEvaluation(Cast(d, FloatType), null) checkEvaluation(Cast(d, DoubleType), null) + checkEvaluation(Cast(d, DecimalType.Unlimited), null) + checkEvaluation(Cast(d, DecimalType(10, 2)), null) checkEvaluation(Cast(d, StringType), "1970-01-01") checkEvaluation(Cast(Cast(d, TimestampType), StringType), "1970-01-01 00:00:00") } @@ -374,7 +436,7 @@ class ExpressionEvaluationSuite extends FunSuite { millis.toFloat / 1000) checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType), millis.toDouble / 1000) - checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1) + checkEvaluation(Cast(Literal(Decimal(1)) cast TimestampType, DecimalType.Unlimited), Decimal(1)) // A test for higher precision than millis checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001) @@ -673,7 +735,7 @@ class ExpressionEvaluationSuite extends FunSuite { val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) val d = 'a.double.at(0) - + for ((row, expected) <- rowSequence zip expectedResults) { checkEvaluation(Sqrt(d), expected, row) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala new file mode 100644 index 0000000000000..5aa263484d5ed --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.types.decimal + +import org.scalatest.{PrivateMethodTester, FunSuite} + +import scala.language.postfixOps + +class DecimalSuite extends FunSuite with PrivateMethodTester { + test("creating decimals") { + /** Check that a Decimal has the given string representation, precision and scale */ + def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { + assert(d.toString === string) + assert(d.precision === precision) + assert(d.scale === scale) + } + + checkDecimal(new Decimal(), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) + checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) + checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1) + checkDecimal(Decimal("10.030"), "10.030", 5, 3) + checkDecimal(Decimal(10.03), "10.03", 4, 2) + checkDecimal(Decimal(17L), "17", 20, 0) + checkDecimal(Decimal(17), "17", 10, 0) + checkDecimal(Decimal(17L, 2, 1), "1.7", 2, 1) + checkDecimal(Decimal(170L, 4, 2), "1.70", 4, 2) + checkDecimal(Decimal(17L, 24, 1), "1.7", 24, 1) + checkDecimal(Decimal(1e17.toLong, 18, 0), 1e17.toLong.toString, 18, 0) + checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0) + checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0) + intercept[IllegalArgumentException](Decimal(170L, 2, 1)) + intercept[IllegalArgumentException](Decimal(170L, 2, 0)) + intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1)) + intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1)) + intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0)) + } + + test("double and long values") { + /** Check that a Decimal converts to the given double and long values */ + def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = { + assert(d.toDouble === doubleValue) + assert(d.toLong === longValue) + } + + checkValues(new Decimal(), 0.0, 0L) + checkValues(Decimal(BigDecimal("10.030")), 10.03, 10L) + checkValues(Decimal(BigDecimal("10.030"), 4, 1), 10.0, 10L) + checkValues(Decimal(BigDecimal("-9.95"), 4, 1), -10.0, -10L) + checkValues(Decimal(10.03), 10.03, 10L) + checkValues(Decimal(17L), 17.0, 17L) + checkValues(Decimal(17), 17.0, 17L) + checkValues(Decimal(17L, 2, 1), 1.7, 1L) + checkValues(Decimal(170L, 4, 2), 1.7, 1L) + checkValues(Decimal(1e16.toLong), 1e16, 1e16.toLong) + checkValues(Decimal(1e17.toLong), 1e17, 1e17.toLong) + checkValues(Decimal(1e18.toLong), 1e18, 1e18.toLong) + checkValues(Decimal(2e18.toLong), 2e18, 2e18.toLong) + checkValues(Decimal(Long.MaxValue), Long.MaxValue.toDouble, Long.MaxValue) + checkValues(Decimal(Long.MinValue), Long.MinValue.toDouble, Long.MinValue) + checkValues(Decimal(Double.MaxValue), Double.MaxValue, 0L) + checkValues(Decimal(Double.MinValue), Double.MinValue, 0L) + } + + // Accessor for the BigDecimal value of a Decimal, which will be null if it's using Longs + private val decimalVal = PrivateMethod[BigDecimal]('decimalVal) + + /** Check whether a decimal is represented compactly (passing whether we expect it to be) */ + private def checkCompact(d: Decimal, expected: Boolean): Unit = { + val isCompact = d.invokePrivate(decimalVal()).eq(null) + assert(isCompact == expected, s"$d ${if (expected) "was not" else "was"} compact") + } + + test("small decimals represented as unscaled long") { + checkCompact(new Decimal(), true) + checkCompact(Decimal(BigDecimal(10.03)), false) + checkCompact(Decimal(BigDecimal(1e20)), false) + checkCompact(Decimal(17L), true) + checkCompact(Decimal(17), true) + checkCompact(Decimal(17L, 2, 1), true) + checkCompact(Decimal(170L, 4, 2), true) + checkCompact(Decimal(17L, 24, 1), true) + checkCompact(Decimal(1e16.toLong), true) + checkCompact(Decimal(1e17.toLong), true) + checkCompact(Decimal(1e18.toLong - 1), true) + checkCompact(Decimal(- 1e18.toLong + 1), true) + checkCompact(Decimal(1e18.toLong - 1, 30, 10), true) + checkCompact(Decimal(- 1e18.toLong + 1, 30, 10), true) + checkCompact(Decimal(1e18.toLong), false) + checkCompact(Decimal(-1e18.toLong), false) + checkCompact(Decimal(1e18.toLong, 30, 10), false) + checkCompact(Decimal(-1e18.toLong, 30, 10), false) + checkCompact(Decimal(Long.MaxValue), false) + checkCompact(Decimal(Long.MinValue), false) + } + + test("hash code") { + assert(Decimal(123).hashCode() === (123).##) + assert(Decimal(-123).hashCode() === (-123).##) + assert(Decimal(123.312).hashCode() === (123.312).##) + assert(Decimal(Int.MaxValue).hashCode() === Int.MaxValue.##) + assert(Decimal(Long.MaxValue).hashCode() === Long.MaxValue.##) + assert(Decimal(BigDecimal(123)).hashCode() === (123).##) + + val reallyBig = BigDecimal("123182312312313232112312312123.1231231231") + assert(Decimal(reallyBig).hashCode() === reallyBig.hashCode) + } + + test("equals") { + // The decimals on the left are stored compactly, while the ones on the right aren't + checkCompact(Decimal(123), true) + checkCompact(Decimal(BigDecimal(123)), false) + checkCompact(Decimal("123"), false) + assert(Decimal(123) === Decimal(BigDecimal(123))) + assert(Decimal(123) === Decimal(BigDecimal("123.00"))) + assert(Decimal(-123) === Decimal(BigDecimal(-123))) + assert(Decimal(-123) === Decimal(BigDecimal("-123.00"))) + } + + test("isZero") { + assert(Decimal(0).isZero) + assert(Decimal(0, 4, 2).isZero) + assert(Decimal("0").isZero) + assert(Decimal("0.000").isZero) + assert(!Decimal(1).isZero) + assert(!Decimal(1, 4, 2).isZero) + assert(!Decimal("1").isZero) + assert(!Decimal("0.001").isZero) + } + + test("arithmetic") { + assert(Decimal(100) + Decimal(-100) === Decimal(0)) + assert(Decimal(100) + Decimal(-100) === Decimal(0)) + assert(Decimal(100) * Decimal(-100) === Decimal(-10000)) + assert(Decimal(1e13) * Decimal(1e13) === Decimal(1e26)) + assert(Decimal(100) / Decimal(-100) === Decimal(-1)) + assert(Decimal(100) / Decimal(0) === null) + assert(Decimal(100) % Decimal(-100) === Decimal(0)) + assert(Decimal(100) % Decimal(3) === Decimal(1)) + assert(Decimal(-100) % Decimal(3) === Decimal(-1)) + assert(Decimal(100) % Decimal(0) === null) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala new file mode 100644 index 0000000000000..0063d31666c85 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.json4s.jackson.JsonMethods.parse +import org.scalatest.FunSuite + +class MetadataSuite extends FunSuite { + + val baseMetadata = new MetadataBuilder() + .putString("purpose", "ml") + .putBoolean("isBase", true) + .build() + + val summary = new MetadataBuilder() + .putLong("numFeatures", 10L) + .build() + + val age = new MetadataBuilder() + .putString("name", "age") + .putLong("index", 1L) + .putBoolean("categorical", false) + .putDouble("average", 45.0) + .build() + + val gender = new MetadataBuilder() + .putString("name", "gender") + .putLong("index", 5) + .putBoolean("categorical", true) + .putStringArray("categories", Array("male", "female")) + .build() + + val metadata = new MetadataBuilder() + .withMetadata(baseMetadata) + .putBoolean("isBase", false) // overwrite an existing key + .putMetadata("summary", summary) + .putLongArray("long[]", Array(0L, 1L)) + .putDoubleArray("double[]", Array(3.0, 4.0)) + .putBooleanArray("boolean[]", Array(true, false)) + .putMetadataArray("features", Array(age, gender)) + .build() + + test("metadata builder and getters") { + assert(age.getLong("index") === 1L) + assert(age.getDouble("average") === 45.0) + assert(age.getBoolean("categorical") === false) + assert(age.getString("name") === "age") + assert(metadata.getString("purpose") === "ml") + assert(metadata.getBoolean("isBase") === false) + assert(metadata.getMetadata("summary") === summary) + assert(metadata.getLongArray("long[]").toSeq === Seq(0L, 1L)) + assert(metadata.getDoubleArray("double[]").toSeq === Seq(3.0, 4.0)) + assert(metadata.getBooleanArray("boolean[]").toSeq === Seq(true, false)) + assert(gender.getStringArray("categories").toSeq === Seq("male", "female")) + assert(metadata.getMetadataArray("features").toSeq === Seq(age, gender)) + } + + test("metadata json conversion") { + val json = metadata.json + withClue("toJson must produce a valid JSON string") { + parse(json) + } + val parsed = Metadata.fromJson(json) + assert(parsed === metadata) + assert(parsed.## === metadata.##) + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index 37e88d72b9172..c38354039d686 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -17,9 +17,7 @@ package org.apache.spark.sql.api.java; -import java.util.HashSet; -import java.util.List; -import java.util.Set; +import java.util.*; /** * The base type of all Spark SQL data types. @@ -54,11 +52,6 @@ public abstract class DataType { */ public static final TimestampType TimestampType = new TimestampType(); - /** - * Gets the DecimalType object. - */ - public static final DecimalType DecimalType = new DecimalType(); - /** * Gets the DoubleType object. */ @@ -151,15 +144,31 @@ public static MapType createMapType( * Creates a StructField by specifying the name ({@code name}), data type ({@code dataType}) and * whether values of this field can be null values ({@code nullable}). */ - public static StructField createStructField(String name, DataType dataType, boolean nullable) { + public static StructField createStructField( + String name, + DataType dataType, + boolean nullable, + Metadata metadata) { if (name == null) { throw new IllegalArgumentException("name should not be null."); } if (dataType == null) { throw new IllegalArgumentException("dataType should not be null."); } + if (metadata == null) { + throw new IllegalArgumentException("metadata should not be null."); + } - return new StructField(name, dataType, nullable); + return new StructField(name, dataType, nullable, metadata); + } + + /** + * Creates a StructField with empty metadata. + * + * @see #createStructField(String, DataType, boolean, Metadata) + */ + public static StructField createStructField(String name, DataType dataType, boolean nullable) { + return createStructField(name, dataType, nullable, (new MetadataBuilder()).build()); } /** @@ -191,5 +200,4 @@ public static StructType createStructType(StructField[] fields) { return new StructType(fields); } - } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java index bc54c078d7a4e..60752451ecfc7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java @@ -19,9 +19,61 @@ /** * The data type representing java.math.BigDecimal values. - * - * {@code DecimalType} is represented by the singleton object {@link DataType#DecimalType}. */ public class DecimalType extends DataType { - protected DecimalType() {} + private boolean hasPrecisionInfo; + private int precision; + private int scale; + + public DecimalType(int precision, int scale) { + this.hasPrecisionInfo = true; + this.precision = precision; + this.scale = scale; + } + + public DecimalType() { + this.hasPrecisionInfo = false; + this.precision = -1; + this.scale = -1; + } + + public boolean isUnlimited() { + return !hasPrecisionInfo; + } + + public boolean isFixed() { + return hasPrecisionInfo; + } + + /** Return the precision, or -1 if no precision is set */ + public int getPrecision() { + return precision; + } + + /** Return the scale, or -1 if no precision is set */ + public int getScale() { + return scale; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DecimalType that = (DecimalType) o; + + if (hasPrecisionInfo != that.hasPrecisionInfo) return false; + if (precision != that.precision) return false; + if (scale != that.scale) return false; + + return true; + } + + @Override + public int hashCode() { + int result = (hasPrecisionInfo ? 1 : 0); + result = 31 * result + precision; + result = 31 * result + scale; + return result; + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java new file mode 100644 index 0000000000000..0f819fb01a76a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, + * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and + * Array[Metadata]. JSON is used for serialization. + * + * The default constructor is private. User should use [[MetadataBuilder]]. + */ +class Metadata extends org.apache.spark.sql.catalyst.util.Metadata { + Metadata(scala.collection.immutable.Map map) { + super(map); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java new file mode 100644 index 0000000000000..6e6b12f0722c5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. + */ +public class MetadataBuilder extends org.apache.spark.sql.catalyst.util.MetadataBuilder { + @Override + public Metadata build() { + return new Metadata(getMap()); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java index b48e2a2c5f953..7c60d492bcdf0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java; +import java.util.Map; + /** * A StructField object represents a field in a StructType object. * A StructField object comprises three fields, {@code String name}, {@code DataType dataType}, @@ -24,20 +26,27 @@ * The field of {@code dataType} specifies the data type of a StructField. * The field of {@code nullable} specifies if values of a StructField can contain {@code null} * values. + * The field of {@code metadata} provides extra information of the StructField. * * To create a {@link StructField}, - * {@link DataType#createStructField(String, DataType, boolean)} + * {@link DataType#createStructField(String, DataType, boolean, Metadata)} * should be used. */ public class StructField { private String name; private DataType dataType; private boolean nullable; + private Metadata metadata; - protected StructField(String name, DataType dataType, boolean nullable) { + protected StructField( + String name, + DataType dataType, + boolean nullable, + Metadata metadata) { this.name = name; this.dataType = dataType; this.nullable = nullable; + this.metadata = metadata; } public String getName() { @@ -52,6 +61,10 @@ public boolean isNullable() { return nullable; } + public Metadata getMetadata() { + return metadata; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -62,6 +75,7 @@ public boolean equals(Object o) { if (nullable != that.nullable) return false; if (!dataType.equals(that.dataType)) return false; if (!name.equals(that.name)) return false; + if (!metadata.equals(that.metadata)) return false; return true; } @@ -71,6 +85,7 @@ public int hashCode() { int result = name.hashCode(); result = 31 * result + dataType.hashCode(); result = 31 * result + (nullable ? 1 : 0); + result = 31 * result + metadata.hashCode(); return result; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java new file mode 100644 index 0000000000000..b751847b464fd --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +import org.apache.spark.annotation.DeveloperApi; + +/** + * ::DeveloperApi:: + * The data type representing User-Defined Types (UDTs). + * UDTs may use any other DataType for an underlying representation. + */ +@DeveloperApi +public abstract class UserDefinedType extends DataType implements Serializable { + + protected UserDefinedType() { } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + UserDefinedType that = (UserDefinedType) o; + return this.sqlType().equals(that.sqlType()); + } + + /** Underlying storage type for this UDT */ + public abstract DataType sqlType(); + + /** Convert the user type to a SQL datum */ + public abstract Object serialize(Object obj); + + /** Convert a SQL datum to the user type */ + public abstract UserType deserialize(Object datum); + + /** Class object for the UserType */ + public abstract Class userClass(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index 3ced11a5e6c11..2e7abac1f1bdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -103,6 +103,19 @@ private[sql] trait CacheManager { cachedData.remove(dataIndex) } + /** Tries to remove the data for the given SchemaRDD from the cache if it's cached */ + private[sql] def tryUncacheQuery( + query: SchemaRDD, + blocking: Boolean = true): Boolean = writeLock { + val planToCache = query.queryExecution.analyzed + val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) + val found = dataIndex >= 0 + if (found) { + cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cachedData.remove(dataIndex) + } + found + } /** Optionally returns cached data for the given SchemaRDD */ private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 07e6e2eccddf4..279495aa64755 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -79,13 +79,13 @@ private[sql] trait SQLConf { private[spark] def dialect: String = getConf(DIALECT, "sql") /** When true tables cached using the in-memory columnar caching will be compressed. */ - private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean + private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "true").toBoolean /** The compression codec for writing to a Parquetfile */ - private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "snappy") + private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "gzip") /** The number of rows that will be */ - private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "1000").toInt + private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "10000").toInt /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt @@ -106,10 +106,10 @@ private[sql] trait SQLConf { * a broadcast value during the physical executions of join operations. Setting this to -1 * effectively disables auto conversion. * - * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is also 10000. + * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is 10000. */ private[spark] def autoBroadcastJoinThreshold: Int = - getConf(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt + getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt /** * The default size in bytes to assign to a logical operator's estimation statistics. By default, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a41a500c9a5d0..31cc4170aa867 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -32,10 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.sources.{DataSourceStrategy, BaseRelation, DDLParser, LogicalRelation} /** * :: AlphaComponent :: @@ -69,13 +70,19 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer + @transient + protected[sql] val ddlParser = new DDLParser + @transient protected[sql] val sqlParser = { val fallback = new catalyst.SqlParser new catalyst.SparkSQLParser(fallback(_)) } - protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser(sql) + protected[sql] def parseSql(sql: String): LogicalPlan = { + ddlParser(sql).getOrElse(sqlParser(sql)) + } + protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } @@ -101,8 +108,14 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = { SparkPlan.currentContext.set(self) - new SchemaRDD(this, - LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self)) + val attributeSeq = ScalaReflection.attributesFor[A] + val schema = StructType.fromAttributes(attributeSeq) + val rowRDD = RDDConversions.productToRowRdd(rdd, schema) + new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self)) + } + + implicit def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = { + logicalPlanToSparkQuery(LogicalRelation(baseRelation)) } /** @@ -266,6 +279,19 @@ class SQLContext(@transient val sparkContext: SparkContext) catalog.registerTable(None, tableName, rdd.queryExecution.logical) } + /** + * Drops the temporary table with the given table name in the catalog. If the table has been + * cached/persisted before, it's also unpersisted. + * + * @param tableName the name of the table to be unregistered. + * + * @group userf + */ + def dropTempTable(tableName: String): Unit = { + tryUncacheQuery(table(tableName)) + catalog.unregisterTable(None, tableName) + } + /** * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is * used for SQL parsing can be configured with 'spark.sql.dialect'. @@ -284,6 +310,14 @@ class SQLContext(@transient val sparkContext: SparkContext) def table(tableName: String): SchemaRDD = new SchemaRDD(this, catalog.lookupRelation(None, tableName)) + /** + * :: DeveloperApi :: + * Allows extra strategies to be injected into the query planner at runtime. Note this API + * should be consider experimental and is not intended to be stable across releases. + */ + @DeveloperApi + var extraStrategies: Seq[Strategy] = Nil + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext @@ -294,7 +328,9 @@ class SQLContext(@transient val sparkContext: SparkContext) def numPartitions = self.numShufflePartitions val strategies: Seq[Strategy] = + extraStrategies ++ ( CommandStrategy(self) :: + DataSourceStrategy :: TakeOrdered :: HashAggregation :: LeftSemiJoin :: @@ -303,7 +339,7 @@ class SQLContext(@transient val sparkContext: SparkContext) ParquetOperations :: BasicOperators :: CartesianProduct :: - BroadcastNestedLoopJoin :: Nil + BroadcastNestedLoopJoin :: Nil) /** * Used to build table scan operators where complex projection and filtering are done using @@ -408,7 +444,7 @@ class SQLContext(@transient val sparkContext: SparkContext) |${stringOrError(optimizedPlan)} |== Physical Plan == |${stringOrError(executedPlan)} - |Code Generation: ${executedPlan.codegenEnabled} + |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} |== RDD == """.stripMargin.trim } @@ -448,6 +484,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case ArrayType(_, _) => true case MapType(_, _, _) => true case StructType(_) => true + case udt: UserDefinedType[_] => needsConversion(udt.sqlType) case other => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 8b96df10963b3..fbec2f9f4b2c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -17,25 +17,26 @@ package org.apache.spark.sql -import java.util.{Map => JMap, List => JList} +import java.util.{List => JList} -import org.apache.spark.storage.StorageLevel +import org.apache.spark.api.python.SerDeUtil import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ import net.razorvine.pickle.Pickler import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext} import org.apache.spark.annotation.{AlphaComponent, Experimental} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.sql.api.java.JavaSchemaRDD +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} -import org.apache.spark.api.java.JavaRDD +import org.apache.spark.storage.StorageLevel /** * :: AlphaComponent :: @@ -113,18 +114,22 @@ class SchemaRDD( // ========================================================================================= override def compute(split: Partition, context: TaskContext): Iterator[Row] = - firstParent[Row].compute(split, context).map(_.copy()) + firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema)) override def getPartitions: Array[Partition] = firstParent[Row].partitions - override protected def getDependencies: Seq[Dependency[_]] = + override protected def getDependencies: Seq[Dependency[_]] = { + schema // Force reification of the schema so it is available on executors. + List(new OneToOneDependency(queryExecution.toRdd)) + } - /** Returns the schema of this SchemaRDD (represented by a [[StructType]]). - * - * @group schema - */ - def schema: StructType = queryExecution.analyzed.schema + /** + * Returns the schema of this SchemaRDD (represented by a [[StructType]]). + * + * @group schema + */ + lazy val schema: StructType = queryExecution.analyzed.schema // ======================================================================= // Query DSL @@ -382,12 +387,8 @@ class SchemaRDD( */ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { val fieldTypes = schema.fields.map(_.dataType) - this.mapPartitions { iter => - val pickle = new Pickler - iter.map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)) - } + val jrdd = this.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 15516afb95504..fd5f4abcbcd65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.LogicalRDD * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) */ private[sql] trait SchemaRDDLike { - @transient val sqlContext: SQLContext + @transient def sqlContext: SQLContext @transient val baseLogicalPlan: LogicalPlan private[sql] def baseSchemaRDD: SchemaRDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 595b4aa36eae3..6d4c0d82ac7af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -78,7 +78,7 @@ private[sql] trait UDFRegistration { s""" def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = { def builder(e: Seq[Expression]) = - ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } """ @@ -87,112 +87,112 @@ private[sql] trait UDFRegistration { // scalastyle:off def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 082ae03eef03f..4c0869e05b029 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -23,12 +23,14 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.{SQLContext, StructType => SStructType} +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} -import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation} +import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType import org.apache.spark.util.Utils @@ -39,6 +41,10 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc)) + def baseRelationToSchemaRDD(baseRelation: BaseRelation): JavaSchemaRDD = { + new JavaSchemaRDD(sqlContext, LogicalRelation(baseRelation)) + } + /** * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is * used for SQL parsing can be configured with 'spark.sql.dialect'. @@ -86,9 +92,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { /** * Applies a schema to an RDD of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. */ def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): JavaSchemaRDD = { - val schema = getSchema(beanClass) + val attributeSeq = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.rdd.mapPartitions { iter => // BeanInfo is not serializable so we must rediscover it remotely for each partition. @@ -99,11 +108,13 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { iter.map { row => new GenericRow( - extractors.map(e => DataTypeConversions.convertJavaToCatalyst(e.invoke(row))).toArray[Any] + extractors.zip(attributeSeq).map { case (e, attr) => + DataTypeConversions.convertJavaToCatalyst(e.invoke(row), attr.dataType) + }.toArray[Any] ): ScalaRow } } - new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext)) + new JavaSchemaRDD(sqlContext, LogicalRDD(attributeSeq, rowRdd)(sqlContext)) } /** @@ -190,14 +201,21 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { sqlContext.registerRDDAsTable(rdd.baseSchemaRDD, tableName) } - /** Returns a Catalyst Schema for the given java bean class. */ + /** + * Returns a Catalyst Schema for the given java bean class. + */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. val beanInfo = Introspector.getBeanInfo(beanClass) + // Note: The ordering of elements may differ from when the schema is inferred in Scala. + // This is because beanInfo.getPropertyDescriptors gives no guarantees about + // element ordering. val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") fields.map { property => val (dataType, nullable) = property.getPropertyType match { + case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) case c: Class[_] if c == classOf[java.lang.String] => (org.apache.spark.sql.StringType, true) case c: Class[_] if c == java.lang.Short.TYPE => @@ -230,7 +248,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { case c: Class[_] if c == classOf[java.lang.Boolean] => (org.apache.spark.sql.BooleanType, true) case c: Class[_] if c == classOf[java.math.BigDecimal] => - (org.apache.spark.sql.DecimalType, true) + (org.apache.spark.sql.DecimalType(), true) case c: Class[_] if c == classOf[java.sql.Date] => (org.apache.spark.sql.DateType, true) case c: Class[_] if c == classOf[java.sql.Timestamp] => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 1e0ccb368a276..78e8d908fe0c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -47,6 +47,9 @@ class JavaSchemaRDD( private[sql] val baseSchemaRDD = new SchemaRDD(sqlContext, logicalPlan) + /** Returns the underlying Scala SchemaRDD. */ + val schemaRDD: SchemaRDD = baseSchemaRDD + override val classTag = scala.reflect.classTag[Row] override def wrapRDD(rdd: RDD[Row]): JavaRDD[Row] = JavaRDD.fromRDD(rdd) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index df01411f60a05..401798e317e96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.annotation.varargs import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} import scala.collection.JavaConversions @@ -106,6 +108,8 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { } override def hashCode(): Int = row.hashCode() + + override def toString: String = row.toString } object Row { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala new file mode 100644 index 0000000000000..a7d0f4f127ecc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java + +import org.apache.spark.sql.catalyst.types.{UserDefinedType => ScalaUserDefinedType} +import org.apache.spark.sql.{DataType => ScalaDataType} +import org.apache.spark.sql.types.util.DataTypeConversions + +/** + * Scala wrapper for a Java UserDefinedType + */ +private[sql] class JavaToScalaUDTWrapper[UserType](val javaUDT: UserDefinedType[UserType]) + extends ScalaUserDefinedType[UserType] with Serializable { + + /** Underlying storage type for this UDT */ + val sqlType: ScalaDataType = DataTypeConversions.asScalaDataType(javaUDT.sqlType()) + + /** Convert the user type to a SQL datum */ + def serialize(obj: Any): Any = javaUDT.serialize(obj) + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType = javaUDT.deserialize(datum) + + val userClass: java.lang.Class[UserType] = javaUDT.userClass() +} + +/** + * Java wrapper for a Scala UserDefinedType + */ +private[sql] class ScalaToJavaUDTWrapper[UserType](val scalaUDT: ScalaUserDefinedType[UserType]) + extends UserDefinedType[UserType] with Serializable { + + /** Underlying storage type for this UDT */ + val sqlType: DataType = DataTypeConversions.asJavaDataType(scalaUDT.sqlType) + + /** Convert the user type to a SQL datum */ + def serialize(obj: Any): java.lang.Object = scalaUDT.serialize(obj).asInstanceOf[java.lang.Object] + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType = scalaUDT.deserialize(datum) + + val userClass: java.lang.Class[UserType] = scalaUDT.userClass +} + +private[sql] object UDTWrappers { + + def wrapAsScala(udtType: UserDefinedType[_]): ScalaUserDefinedType[_] = { + udtType match { + case t: ScalaToJavaUDTWrapper[_] => t.scalaUDT + case _ => new JavaToScalaUDTWrapper(udtType) + } + } + + def wrapAsJava(udtType: ScalaUserDefinedType[_]): UserDefinedType[_] = { + udtType match { + case t: JavaToScalaUDTWrapper[_] => t.javaUDT + case _ => new ScalaToJavaUDTWrapper(udtType) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 300cef15bf8a4..c68dceef3b142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -79,8 +79,9 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( } private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType]( + columnStats: ColumnStats, columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) + extends BasicColumnBuilder[T, JvmType](columnStats, columnType) with NullableColumnBuilder private[sql] abstract class NativeColumnBuilder[T <: NativeType]( @@ -91,7 +92,7 @@ private[sql] abstract class NativeColumnBuilder[T <: NativeType]( with AllCompressionSchemes with CompressibleColumnBuilder[T] -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new NoopColumnStats, BOOLEAN) +private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) @@ -112,10 +113,11 @@ private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnS private[sql] class TimestampColumnBuilder extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP) -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY) +private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) // TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC) +private[sql] class GenericColumnBuilder + extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) private[sql] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index b9f9f8270045c..668efe4a3b2a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -70,11 +70,30 @@ private[sql] sealed trait ColumnStats extends Serializable { def collectedStatistics: Row } +/** + * A no-op ColumnStats only used for testing purposes. + */ private[sql] class NoopColumnStats extends ColumnStats { + override def gatherStats(row: Row, ordinal: Int): Unit = super.gatherStats(row, ordinal) + + def collectedStatistics = Row(null, null, nullCount, count, 0L) +} - override def gatherStats(row: Row, ordinal: Int): Unit = {} +private[sql] class BooleanColumnStats extends ColumnStats { + protected var upper = false + protected var lower = true - override def collectedStatistics = Row() + override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) + if (!row.isNullAt(ordinal)) { + val value = row.getBoolean(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += BOOLEAN.defaultSize + } + } + + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ByteColumnStats extends ColumnStats { @@ -229,3 +248,25 @@ private[sql] class TimestampColumnStats extends ColumnStats { def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } + +private[sql] class BinaryColumnStats extends ColumnStats { + override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) + if (!row.isNullAt(ordinal)) { + sizeInBytes += BINARY.actualSize(row, ordinal) + } + } + + def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes) +} + +private[sql] class GenericColumnStats extends ColumnStats { + override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) + if (!row.isNullAt(ordinal)) { + sizeInBytes += GENERIC.actualSize(row, ordinal) + } + } + + def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index ee63134f56d8c..455b415d9d959 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -161,6 +161,9 @@ private[sql] case class InMemoryRelation( } def cachedColumnBuffers = _cachedColumnBuffers + + override protected def otherCopyArgs: Seq[AnyRef] = + Seq(_cachedColumnBuffers, statisticsToBePropagated) } private[sql] case class InMemoryColumnarTableScan( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 04c51a1ee4b97..ed6b95dc6d9d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,29 +19,32 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataType, StructType, Row, SQLContext} import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.ScalaReflection.Schema import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.types.UserDefinedType /** * :: DeveloperApi :: */ @DeveloperApi object RDDConversions { - def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { + def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = { data.mapPartitions { iterator => if (iterator.isEmpty) { Iterator.empty } else { val bufferedIterator = iterator.buffered val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) - + val schemaFields = schema.fields.toArray bufferedIterator.map { r => var i = 0 while (i < mutableRow.length) { - mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i)) + mutableRow(i) = + ScalaReflection.convertToCatalyst(r.productElement(i), schemaFields(i).dataType) i += 1 } @@ -50,12 +53,6 @@ object RDDConversions { } } } - - /* - def toLogicalPlan[A <: Product : TypeTag](productRdd: RDD[A]): LogicalPlan = { - LogicalRDD(ScalaReflection.attributesFor[A], productToRowRdd(productRdd)) - } - */ } case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index b3edd5020fa8c..087b0ecbb25c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -70,16 +70,29 @@ case class GeneratedAggregate( val computeFunctions = aggregatesToCompute.map { case c @ Count(expr) => + // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its + // UnscaledValue will be null if and only if x is null; helps with Average on decimals + val toCount = expr match { + case UnscaledValue(e) => e + case _ => expr + } val currentCount = AttributeReference("currentCount", LongType, nullable = false)() val initialValue = Literal(0L) - val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) val result = currentCount AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) case Sum(expr) => - val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() - val initialValue = Cast(Literal(0L), expr.dataType) + val resultType = expr.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) + case _ => + expr.dataType + } + + val currentSum = AttributeReference("currentSum", resultType, nullable = false)() + val initialValue = Cast(Literal(0L), resultType) // Coalasce avoids double calculation... // but really, common sub expression elimination would be better.... @@ -93,10 +106,26 @@ case class GeneratedAggregate( val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() val initialCount = Literal(0L) val initialSum = Cast(Literal(0L), expr.dataType) - val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + + // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its + // UnscaledValue will be null if and only if x is null; helps with Average on decimals + val toCount = expr match { + case UnscaledValue(e) => e + case _ => expr + } + + val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) - val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType)) + val resultType = expr.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + DoubleType + } + val result = Divide(Cast(currentSum, resultType), Cast(currentCount, resultType)) AggregateEvaluation( currentCount :: currentSum :: Nil, @@ -142,7 +171,7 @@ case class GeneratedAggregate( val computationSchema = computeFunctions.flatMap(_.schema) - val resultMap: Map[TreeNodeRef, Expression] = + val resultMap: Map[TreeNodeRef, Expression] = aggregatesToCompute.zip(computeFunctions).map { case (agg, func) => new TreeNodeRef(agg) -> func.result }.toMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b1a7948b66cb6..81c60e00505c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -20,10 +20,8 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging import org.apache.spark.rdd.RDD - - import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.{ScalaReflection, trees} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -82,7 +80,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. */ - def executeCollect(): Array[Row] = execute().map(_.copy()).collect() + def executeCollect(): Array[Row] = + execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 077e6ebc5f11e..84d96e612f0dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -29,6 +29,7 @@ import com.twitter.chill.{AllScalaRegistrar, ResourcePool} import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.MutablePair import org.apache.spark.util.Utils @@ -51,6 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[LongHashSet], new LongHashSetSerializer) kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], new OpenHashSetSerializer) + kryo.register(classOf[Decimal]) kryo.setReferences(false) kryo.setClassLoader(Utils.getSparkClassLoader) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 79e4ddb8c4f5d..cc7e0c05ffc70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLContext, execution} +import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ @@ -280,7 +280,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val nPartitions = if (data.isEmpty) 1 else numPartitions PhysicalRDD( output, - RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil + RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions), + StructType.fromAttributes(output))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => @@ -304,6 +305,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case class CommandStrategy(context: SQLContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case r: RunnableCommand => ExecutedCommand(r) :: Nil case logical.SetCommand(kv) => Seq(execution.SetCommand(kv, plan.output)(context)) case logical.ExplainCommand(logicalPlan, extended) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 977f3c9f32096..1b8ba3ace2a82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan) partsScanned += numPartsToTry } - buf.toArray + buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) } override def execute() = { @@ -176,10 +176,11 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) override def output = child.output override def outputPartitioning = SinglePartition - val ordering = new RowOrdering(sortOrder, child.output) + val ord = new RowOrdering(sortOrder, child.output) // TODO: Is this copying for no reason? - override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) + override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord) + .map(ScalaReflection.convertRowToScala(_, this.schema)) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 5859eba408ee1..f23b9c48cfb40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,10 +21,12 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} +import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.{Row, SQLConf, SQLContext} +import org.apache.spark.sql.{SQLConf, SQLContext} +// TODO: DELETE ME... trait Command { this: SparkPlan => @@ -44,6 +46,35 @@ trait Command { override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) } +// TODO: Replace command with runnable command. +trait RunnableCommand extends logical.Command { + self: Product => + + def output: Seq[Attribute] + def run(sqlContext: SQLContext): Seq[Row] +} + +case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { + /** + * A concrete command should override this lazy field to wrap up any side effects caused by the + * command or any other computation that should be evaluated exactly once. The value of this field + * can be used as the contents of the corresponding RDD generated from the physical plan of this + * command. + * + * The `execute()` method of all the physical command classes should reference `sideEffectResult` + * so that the command can be executed eagerly right after the command query is created. + */ + protected[sql] lazy val sideEffectResult: Seq[Row] = cmd.run(sqlContext) + + override def output = cmd.output + + override def children = Nil + + override def executeCollect(): Array[Row] = sideEffectResult.toArray + + override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) +} + /** * :: DeveloperApi :: */ @@ -53,50 +84,35 @@ case class SetCommand(kv: Option[(String, Option[String])], output: Seq[Attribut extends LeafNode with Command with Logging { override protected lazy val sideEffectResult: Seq[Row] = kv match { - // Set value for the key. - case Some((key, Some(value))) => - if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { - logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + // Configures the deprecated "mapred.reduce.tasks" property. + case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) => + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") - context.setConf(SQLConf.SHUFFLE_PARTITIONS, value) - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value")) - } else { - context.setConf(key, value) - Seq(Row(s"$key=$value")) - } - - // Query the value bound to the key. + context.setConf(SQLConf.SHUFFLE_PARTITIONS, value) + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value")) + + // Configures a single property. + case Some((key, Some(value))) => + context.setConf(key, value) + Seq(Row(s"$key=$value")) + + // Queries all key-value pairs that are set in the SQLConf of the context. Notice that different + // from Hive, here "SET -v" is an alias of "SET". (In Hive, "SET" returns all changed properties + // while "SET -v" returns all properties.) + case Some(("-v", None)) | None => + context.getAllConfs.map { case (k, v) => Row(s"$k=$v") }.toSeq + + // Queries the deprecated "mapred.reduce.tasks" property. + case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) => + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}")) + + // Queries a single property. case Some((key, None)) => - // TODO (lian) This is just a workaround to make the Simba ODBC driver work. - // Should remove this once we get the ODBC driver updated. - if (key == "-v") { - val hiveJars = Seq( - "hive-exec-0.12.0.jar", - "hive-service-0.12.0.jar", - "hive-common-0.12.0.jar", - "hive-hwi-0.12.0.jar", - "hive-0.12.0.jar").mkString(":") - - context.getAllConfs.map { case (k, v) => - Row(s"$k=$v") - }.toSeq ++ Seq( - Row("system:java.class.path=" + hiveJars), - Row("system:sun.java.command=shark.SharkServer2")) - } else { - if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { - logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + - s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}")) - } else { - Seq(Row(s"$key=${context.getConf(key, "")}")) - } - } - - // Query all key-value pairs that are set in the SQLConf of the context. - case _ => - context.getAllConfs.map { case (k, v) => - Row(s"$k=$v") - }.toSeq + Seq(Row(s"$key=${context.getConf(key, "")}")) } override def otherCopyArgs = context :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 8fd35880eedfe..5cf2a785adc7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -49,7 +49,8 @@ case class BroadcastHashJoin( @transient private val broadcastFuture = future { - val input: Array[Row] = buildPlan.executeCollect() + // Note that we use .execute().collect() because we don't want to convert data to Scala types + val input: Array[Row] = buildPlan.execute().map(_.copy()).collect() val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) sparkContext.broadcast(hashed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index a1961bba1899e..a83cf5d441d1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.{List => JList, Map => JMap} +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -116,7 +118,7 @@ object EvaluatePython { def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { case (null, _) => null - case (row: Row, struct: StructType) => + case (row: Seq[Any], struct: StructType) => val fields = struct.fields.map(field => field.dataType) row.zip(fields).map { case (obj, dataType) => toJava(obj, dataType) @@ -133,6 +135,10 @@ object EvaluatePython { case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type }.asJava + case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) + + case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal + // Pyrolite can handle Timestamp case (other, _) => other } @@ -173,6 +179,9 @@ object EvaluatePython { case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime()) + case (_, udt: UserDefinedType[_]) => + fromJava(obj, udt.sqlType) + case (c: Int, ByteType) => c.toByte case (c: Long, ByteType) => c.toByte case (c: Int, ShortType) => c.toShort diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala new file mode 100644 index 0000000000000..fc70c183437f6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.json + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources._ + +private[sql] class DefaultSource extends RelationProvider { + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + + JSONRelation(fileName, samplingRatio)(sqlContext) + } +} + +private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)( + @transient val sqlContext: SQLContext) + extends TableScan { + + private def baseRDD = sqlContext.sparkContext.textFile(fileName) + + override val schema = + JsonRDD.inferSchema( + baseRDD, + samplingRatio, + sqlContext.columnNameOfCorruptRecord) + + override def buildScan() = + JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 047dc85df6c1d..d9d7a3fea3963 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.json +import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.types.util.DataTypeConversions + import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal @@ -71,16 +74,18 @@ private[sql] object JsonRDD extends Logging { def makeStruct(values: Seq[Seq[String]], prefix: Seq[String]): StructType = { val (topLevel, structLike) = values.partition(_.size == 1) + val topLevelFields = topLevel.filter { name => resolved.get(prefix ++ name).get match { case ArrayType(elementType, _) => { def hasInnerStruct(t: DataType): Boolean = t match { - case s: StructType => false + case s: StructType => true case ArrayType(t1, _) => hasInnerStruct(t1) - case o => true + case o => false } - hasInnerStruct(elementType) + // Check if this array has inner struct. + !hasInnerStruct(elementType) } case struct: StructType => false case _ => true @@ -88,8 +93,11 @@ private[sql] object JsonRDD extends Logging { }.map { a => StructField(a.head, resolved.get(prefix ++ a).get, nullable = true) } + val topLevelFieldNameSet = topLevelFields.map(_.name) - val structFields: Seq[StructField] = structLike.groupBy(_(0)).map { + val structFields: Seq[StructField] = structLike.groupBy(_(0)).filter { + case (name, _) => !topLevelFieldNameSet.contains(name) + }.map { case (name, fields) => { val nestedFields = fields.map(_.tail) val structType = makeStruct(nestedFields, prefix :+ name) @@ -117,10 +125,7 @@ private[sql] object JsonRDD extends Logging { } }.flatMap(field => field).toSeq - StructType( - (topLevelFields ++ structFields).sortBy { - case StructField(name, _, _) => name - }) + StructType((topLevelFields ++ structFields).sortBy(_.name)) } makeStruct(resolved.keySet.toSeq, Nil) @@ -128,7 +133,7 @@ private[sql] object JsonRDD extends Logging { private[sql] def nullTypeToStringType(struct: StructType): StructType = { val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable) => { + case StructField(fieldName, dataType, nullable, _) => { val newType = dataType match { case NullType => StringType case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) @@ -163,9 +168,7 @@ private[sql] object JsonRDD extends Logging { StructField(name, dataType, true) } } - StructType(newFields.toSeq.sortBy { - case StructField(name, _, _) => name - }) + StructType(newFields.toSeq.sortBy(_.name)) } case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) @@ -180,9 +183,9 @@ private[sql] object JsonRDD extends Logging { ScalaReflection.typeOfObject orElse { // Since we do not have a data type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. - case value: java.math.BigInteger => DecimalType + case value: java.math.BigInteger => DecimalType.Unlimited // DecimalType's JVMType is scala BigDecimal. - case value: java.math.BigDecimal => DecimalType + case value: java.math.BigDecimal => DecimalType.Unlimited // Unexpected data type. case _ => StringType } @@ -324,13 +327,13 @@ private[sql] object JsonRDD extends Logging { } } - private def toDecimal(value: Any): BigDecimal = { + private def toDecimal(value: Any): Decimal = { value match { - case value: java.lang.Integer => BigDecimal(value) - case value: java.lang.Long => BigDecimal(value) - case value: java.math.BigInteger => BigDecimal(value) - case value: java.lang.Double => BigDecimal(value) - case value: java.math.BigDecimal => BigDecimal(value) + case value: java.lang.Integer => Decimal(value) + case value: java.lang.Long => Decimal(value) + case value: java.math.BigInteger => Decimal(BigDecimal(value)) + case value: java.lang.Double => Decimal(value) + case value: java.math.BigDecimal => Decimal(BigDecimal(value)) } } @@ -357,7 +360,8 @@ private[sql] object JsonRDD extends Logging { case (key, value) => if (count > 0) builder.append(",") count += 1 - builder.append(s"""\"${key}\":${toString(value)}""") + val stringValue = if (value.isInstanceOf[String]) s"""\"$value\"""" else toString(value) + builder.append(s"""\"${key}\":${stringValue}""") } builder.append("}") @@ -375,7 +379,7 @@ private[sql] object JsonRDD extends Logging { private def toDate(value: Any): Date = { value match { // only support string as date - case value: java.lang.String => Date.valueOf(value) + case value: java.lang.String => new Date(DataTypeConversions.stringToTime(value).getTime) } } @@ -383,7 +387,7 @@ private[sql] object JsonRDD extends Logging { value match { case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) case value: java.lang.Long => new Timestamp(value) - case value: java.lang.String => Timestamp.valueOf(value) + case value: java.lang.String => toTimestamp(DataTypeConversions.stringToTime(value).getTime) } } @@ -396,7 +400,7 @@ private[sql] object JsonRDD extends Logging { case IntegerType => value.asInstanceOf[IntegerType.JvmType] case LongType => toLong(value) case DoubleType => toDouble(value) - case DecimalType => toDecimal(value) + case DecimalType() => toDecimal(value) case BooleanType => value.asInstanceOf[BooleanType.JvmType] case NullType => null @@ -413,7 +417,7 @@ private[sql] object JsonRDD extends Logging { // TODO: Reuse the row instead of creating a new one for every record. val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { - case (StructField(name, dataType, _), i) => + case (StructField(name, dataType, _, _), i) => row.update(i, json.get(name).flatMap(v => Option(v)).map( enforceCorrectType(_, dataType)).orNull) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index e98d151286818..51dad54f1a3f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.SparkPlan /** * Allows the execution of relational queries, including those expressed in SQL using Spark. @@ -125,6 +126,9 @@ package object sql { @DeveloperApi type DataType = catalyst.types.DataType + @DeveloperApi + val DataType = catalyst.types.DataType + /** * :: DeveloperApi :: * @@ -180,6 +184,20 @@ package object sql { * * The data type representing `scala.math.BigDecimal` values. * + * TODO(matei): explain precision and scale + * + * @group dataType + */ + @DeveloperApi + type DecimalType = catalyst.types.DecimalType + + /** + * :: DeveloperApi :: + * + * The data type representing `scala.math.BigDecimal` values. + * + * TODO(matei): explain precision and scale + * * @group dataType */ @DeveloperApi @@ -414,4 +432,32 @@ package object sql { */ @DeveloperApi val StructField = catalyst.types.StructField + + /** + * Converts a logical plan into zero or more SparkPlans. + */ + @DeveloperApi + type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + + /** + * :: DeveloperApi :: + * + * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, + * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and + * Array[Metadata]. JSON is used for serialization. + * + * The default constructor is private. User should use either [[MetadataBuilder]] or + * [[Metadata$#fromJson]] to create Metadata instances. + * + * @param map an immutable map that stores the data + */ + @DeveloperApi + type Metadata = catalyst.util.Metadata + + /** + * :: DeveloperApi :: + * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. + */ + @DeveloperApi + type MetadataBuilder = catalyst.util.MetadataBuilder } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 2fc7e1cf23ab7..1bbb66aaa19a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.parquet +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} @@ -75,6 +77,9 @@ private[sql] object CatalystConverter { parent: CatalystConverter): Converter = { val fieldType: DataType = field.dataType fieldType match { + case udt: UserDefinedType[_] => { + createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) + } // For native JVM types we use a converter with native arrays case ArrayType(elementType: NativeType, false) => { new CatalystNativeArrayConverter(elementType, fieldIndex, parent) @@ -117,6 +122,12 @@ private[sql] object CatalystConverter { parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) } } + case d: DecimalType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addBinary(value: Binary): Unit = + parent.updateDecimal(fieldIndex, value, d) + } + } // All other primitive types use the default converter case ctype: PrimitiveType => { // note: need the type tag here! new CatalystPrimitiveConverter(parent, fieldIndex) @@ -191,6 +202,10 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.toStringUsingUTF8) + protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { + updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) + } + protected[parquet] def isRootConverter: Boolean = parent == null protected[parquet] def clearBuffer(): Unit @@ -201,6 +216,27 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { * @return */ def getCurrentRecord: Row = throw new UnsupportedOperationException + + /** + * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in + * a long (i.e. precision <= 18) + */ + protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = { + val precision = ctype.precisionInfo.get.precision + val scale = ctype.precisionInfo.get.scale + val bytes = value.getBytes + require(bytes.length <= 16, "Decimal field too large to read") + var unscaled = 0L + var i = 0 + while (i < bytes.length) { + unscaled = (unscaled << 8) | (bytes(i) & 0xFF) + i += 1 + } + // Make sure unscaled has the right sign, by sign-extending the first bit + val numBits = 8 * bytes.length + unscaled = (unscaled << (64 - numBits)) >> (64 - numBits) + dest.set(unscaled, precision, scale) + } } /** @@ -222,8 +258,8 @@ private[parquet] class CatalystGroupConverter( schema, index, parent, - current=null, - buffer=new ArrayBuffer[Row]( + current = null, + buffer = new ArrayBuffer[Row]( CatalystArrayConverter.INITIAL_ARRAY_SIZE)) /** @@ -268,7 +304,7 @@ private[parquet] class CatalystGroupConverter( override def end(): Unit = { if (!isRootConverter) { - assert(current!=null) // there should be no empty groups + assert(current != null) // there should be no empty groups buffer.append(new GenericRow(current.toArray)) parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]])) } @@ -325,7 +361,7 @@ private[parquet] class CatalystPrimitiveRowConverter( override def end(): Unit = {} - // Overriden here to avoid auto-boxing for primitive types + // Overridden here to avoid auto-boxing for primitive types override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = current.setBoolean(fieldIndex, value) @@ -352,6 +388,16 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = current.setString(fieldIndex, value.toStringUsingUTF8) + + override protected[parquet] def updateDecimal( + fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { + var decimal = current(fieldIndex).asInstanceOf[Decimal] + if (decimal == null) { + decimal = new Decimal + current(fieldIndex) = decimal + } + readDecimal(decimal, value, ctype) + } } /** @@ -490,7 +536,7 @@ private[parquet] class CatalystNativeArrayConverter( override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = throw new UnsupportedOperationException - // Overriden here to avoid auto-boxing for primitive types + // Overridden here to avoid auto-boxing for primitive types override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = { checkGrowBuffer() buffer(elements) = value.asInstanceOf[NativeType] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 7c83f1cad7d71..1e67799e8399a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -18,32 +18,40 @@ package org.apache.spark.sql.parquet import java.nio.ByteBuffer +import java.sql.{Date, Timestamp} import org.apache.hadoop.conf.Configuration -import parquet.filter._ -import parquet.filter.ColumnPredicates._ +import parquet.common.schema.ColumnPath +import parquet.filter2.compat.FilterCompat +import parquet.filter2.compat.FilterCompat._ +import parquet.filter2.predicate.Operators.{Column, SupportsLtGt} +import parquet.filter2.predicate.{FilterApi, FilterPredicate} +import parquet.filter2.predicate.FilterApi._ +import parquet.io.api.Binary import parquet.column.ColumnReader import com.google.common.io.BaseEncoding import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.parquet.ParquetColumns._ private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" // set this to false if pushdown should be disabled val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.hints.parquetFilterPushdown" - def createRecordFilter(filterExpressions: Seq[Expression]): UnboundRecordFilter = { + def createRecordFilter(filterExpressions: Seq[Expression]): Filter = { val filters: Seq[CatalystFilter] = filterExpressions.collect { case (expression: Expression) if createFilter(expression).isDefined => createFilter(expression).get } - if (filters.length > 0) filters.reduce(AndRecordFilter.and) else null + if (filters.length > 0) FilterCompat.get(filters.reduce(FilterApi.and)) else null } def createFilter(expression: Expression): Option[CatalystFilter] = { @@ -52,78 +60,188 @@ private[sql] object ParquetFilters { literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case BooleanType => - ComparisonFilter.createBooleanFilter(name, literal.value.asInstanceOf[Boolean], predicate) + ComparisonFilter.createBooleanEqualityFilter( + name, + literal.value.asInstanceOf[Boolean], + predicate) + case ByteType => + new ComparisonFilter( + name, + FilterApi.eq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.eq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => - ComparisonFilter.createIntFilter( + new ComparisonFilter( name, - (x: Int) => x == literal.value.asInstanceOf[Int], + FilterApi.eq(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x == literal.value.asInstanceOf[Long], + FilterApi.eq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x == literal.value.asInstanceOf[Double], + FilterApi.eq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( name, - (x: Float) => x == literal.value.asInstanceOf[Float], + FilterApi.eq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) case StringType => - ComparisonFilter.createStringFilter(name, literal.value.asInstanceOf[String], predicate) + ComparisonFilter.createStringEqualityFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryEqualityFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.eq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.eq(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.eq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } + def createLessThanFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { - case IntegerType => - ComparisonFilter.createIntFilter( + case ByteType => + new ComparisonFilter( + name, + FilterApi.lt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( name, - (x: Int) => x < literal.value.asInstanceOf[Int], + FilterApi.lt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) + case IntegerType => + new ComparisonFilter( + name, + FilterApi.lt(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x < literal.value.asInstanceOf[Long], + FilterApi.lt(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x < literal.value.asInstanceOf[Double], + FilterApi.lt(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( + name, + FilterApi.lt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), + predicate) + case StringType => + ComparisonFilter.createStringLessThanFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryLessThanFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( name, - (x: Float) => x < literal.value.asInstanceOf[Float], + FilterApi.lt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.lt(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.lt(decimalColumn(name), literal.value.asInstanceOf[Decimal]), predicate) } def createLessThanOrEqualFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.ltEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.ltEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => - ComparisonFilter.createIntFilter( + new ComparisonFilter( name, - (x: Int) => x <= literal.value.asInstanceOf[Int], + FilterApi.ltEq(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x <= literal.value.asInstanceOf[Long], + FilterApi.ltEq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x <= literal.value.asInstanceOf[Double], + FilterApi.ltEq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( name, - (x: Float) => x <= literal.value.asInstanceOf[Float], + FilterApi.ltEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), + predicate) + case StringType => + ComparisonFilter.createStringLessThanOrEqualFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryLessThanOrEqualFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.ltEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.ltEq(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.ltEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), predicate) } // TODO: combine these two types somehow? @@ -131,49 +249,122 @@ private[sql] object ParquetFilters { name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.gt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.gt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => - ComparisonFilter.createIntFilter( + new ComparisonFilter( name, - (x: Int) => x > literal.value.asInstanceOf[Int], + FilterApi.gt(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x > literal.value.asInstanceOf[Long], + FilterApi.gt(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x > literal.value.asInstanceOf[Double], + FilterApi.gt(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( + name, + FilterApi.gt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), + predicate) + case StringType => + ComparisonFilter.createStringGreaterThanFilter( name, - (x: Float) => x > literal.value.asInstanceOf[Float], + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryGreaterThanFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.gt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.gt(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.gt(decimalColumn(name), literal.value.asInstanceOf[Decimal]), predicate) } def createGreaterThanOrEqualFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.gtEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.gtEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => - ComparisonFilter.createIntFilter( - name, (x: Int) => x >= literal.value.asInstanceOf[Int], + new ComparisonFilter( + name, + FilterApi.gtEq(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x >= literal.value.asInstanceOf[Long], + FilterApi.gtEq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x >= literal.value.asInstanceOf[Double], + FilterApi.gtEq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( name, - (x: Float) => x >= literal.value.asInstanceOf[Float], + FilterApi.gtEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), + predicate) + case StringType => + ComparisonFilter.createStringGreaterThanOrEqualFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryGreaterThanOrEqualFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.gtEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.gtEq(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.gtEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), predicate) } @@ -209,25 +400,25 @@ private[sql] object ParquetFilters { case _ => None } } - case p @ EqualTo(left: Literal, right: NamedExpression) if !right.nullable => + case p @ EqualTo(left: Literal, right: NamedExpression) if left.dataType != NullType => Some(createEqualityFilter(right.name, left, p)) - case p @ EqualTo(left: NamedExpression, right: Literal) if !left.nullable => + case p @ EqualTo(left: NamedExpression, right: Literal) if right.dataType != NullType => Some(createEqualityFilter(left.name, right, p)) - case p @ LessThan(left: Literal, right: NamedExpression) if !right.nullable => + case p @ LessThan(left: Literal, right: NamedExpression) => Some(createLessThanFilter(right.name, left, p)) - case p @ LessThan(left: NamedExpression, right: Literal) if !left.nullable => + case p @ LessThan(left: NamedExpression, right: Literal) => Some(createLessThanFilter(left.name, right, p)) - case p @ LessThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable => + case p @ LessThanOrEqual(left: Literal, right: NamedExpression) => Some(createLessThanOrEqualFilter(right.name, left, p)) - case p @ LessThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable => + case p @ LessThanOrEqual(left: NamedExpression, right: Literal) => Some(createLessThanOrEqualFilter(left.name, right, p)) - case p @ GreaterThan(left: Literal, right: NamedExpression) if !right.nullable => + case p @ GreaterThan(left: Literal, right: NamedExpression) => Some(createGreaterThanFilter(right.name, left, p)) - case p @ GreaterThan(left: NamedExpression, right: Literal) if !left.nullable => + case p @ GreaterThan(left: NamedExpression, right: Literal) => Some(createGreaterThanFilter(left.name, right, p)) - case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable => + case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) => Some(createGreaterThanOrEqualFilter(right.name, left, p)) - case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable => + case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) => Some(createGreaterThanOrEqualFilter(left.name, right, p)) case _ => None } @@ -300,142 +491,206 @@ private[sql] object ParquetFilters { } abstract private[parquet] class CatalystFilter( - @transient val predicate: CatalystPredicate) extends UnboundRecordFilter + @transient val predicate: CatalystPredicate) extends FilterPredicate private[parquet] case class ComparisonFilter( val columnName: String, - private var filter: UnboundRecordFilter, + private var filter: FilterPredicate, @transient override val predicate: CatalystPredicate) extends CatalystFilter(predicate) { - override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { - filter.bind(readers) + override def accept[R](visitor: FilterPredicate.Visitor[R]): R = { + filter.accept(visitor) } } private[parquet] case class OrFilter( - private var filter: UnboundRecordFilter, + private var filter: FilterPredicate, @transient val left: CatalystFilter, @transient val right: CatalystFilter, @transient override val predicate: Or) extends CatalystFilter(predicate) { def this(l: CatalystFilter, r: CatalystFilter) = this( - OrRecordFilter.or(l, r), + FilterApi.or(l, r), l, r, Or(l.predicate, r.predicate)) - override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { - filter.bind(readers) + override def accept[R](visitor: FilterPredicate.Visitor[R]): R = { + filter.accept(visitor); } + } private[parquet] case class AndFilter( - private var filter: UnboundRecordFilter, + private var filter: FilterPredicate, @transient val left: CatalystFilter, @transient val right: CatalystFilter, @transient override val predicate: And) extends CatalystFilter(predicate) { def this(l: CatalystFilter, r: CatalystFilter) = this( - AndRecordFilter.and(l, r), + FilterApi.and(l, r), l, r, And(l.predicate, r.predicate)) - override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { - filter.bind(readers) + override def accept[R](visitor: FilterPredicate.Visitor[R]): R = { + filter.accept(visitor); } + } private[parquet] object ComparisonFilter { - def createBooleanFilter( + def createBooleanEqualityFilter( columnName: String, value: Boolean, predicate: CatalystPredicate): CatalystFilter = new ComparisonFilter( columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToBoolean( - new BooleanPredicateFunction { - def functionToApply(input: Boolean): Boolean = input == value - } - )), + FilterApi.eq(booleanColumn(columnName), value.asInstanceOf[java.lang.Boolean]), predicate) - def createStringFilter( + def createStringEqualityFilter( columnName: String, value: String, predicate: CatalystPredicate): CatalystFilter = new ComparisonFilter( columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToString ( - new ColumnPredicates.PredicateFunction[String] { - def functionToApply(input: String): Boolean = input == value - } - )), + FilterApi.eq(binaryColumn(columnName), Binary.fromString(value)), predicate) - def createIntFilter( + def createStringLessThanFilter( columnName: String, - func: Int => Boolean, + value: String, predicate: CatalystPredicate): CatalystFilter = new ComparisonFilter( columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToInteger( - new IntegerPredicateFunction { - def functionToApply(input: Int) = func(input) - } - )), + FilterApi.lt(binaryColumn(columnName), Binary.fromString(value)), predicate) - def createLongFilter( + def createStringLessThanOrEqualFilter( columnName: String, - func: Long => Boolean, + value: String, predicate: CatalystPredicate): CatalystFilter = new ComparisonFilter( columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToLong( - new LongPredicateFunction { - def functionToApply(input: Long) = func(input) - } - )), + FilterApi.ltEq(binaryColumn(columnName), Binary.fromString(value)), predicate) - def createDoubleFilter( + def createStringGreaterThanFilter( columnName: String, - func: Double => Boolean, + value: String, predicate: CatalystPredicate): CatalystFilter = new ComparisonFilter( columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToDouble( - new DoublePredicateFunction { - def functionToApply(input: Double) = func(input) - } - )), + FilterApi.gt(binaryColumn(columnName), Binary.fromString(value)), predicate) - def createFloatFilter( + def createStringGreaterThanOrEqualFilter( columnName: String, - func: Float => Boolean, + value: String, predicate: CatalystPredicate): CatalystFilter = new ComparisonFilter( columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToFloat( - new FloatPredicateFunction { - def functionToApply(input: Float) = func(input) - } - )), + FilterApi.gtEq(binaryColumn(columnName), Binary.fromString(value)), predicate) + + def createBinaryEqualityFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.eq(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryLessThanFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.lt(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryLessThanOrEqualFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.ltEq(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryGreaterThanFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gt(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryGreaterThanOrEqualFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gtEq(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) +} + +private[spark] object ParquetColumns { + + def byteColumn(columnPath: String): ByteColumn = { + new ByteColumn(ColumnPath.fromDotString(columnPath)) + } + + final class ByteColumn(columnPath: ColumnPath) + extends Column[java.lang.Byte](columnPath, classOf[java.lang.Byte]) with SupportsLtGt + + def shortColumn(columnPath: String): ShortColumn = { + new ShortColumn(ColumnPath.fromDotString(columnPath)) + } + + final class ShortColumn(columnPath: ColumnPath) + extends Column[java.lang.Short](columnPath, classOf[java.lang.Short]) with SupportsLtGt + + + def dateColumn(columnPath: String): DateColumn = { + new DateColumn(ColumnPath.fromDotString(columnPath)) + } + + final class DateColumn(columnPath: ColumnPath) + extends Column[WrappedDate](columnPath, classOf[WrappedDate]) with SupportsLtGt + + def timestampColumn(columnPath: String): TimestampColumn = { + new TimestampColumn(ColumnPath.fromDotString(columnPath)) + } + + final class TimestampColumn(columnPath: ColumnPath) + extends Column[WrappedTimestamp](columnPath, classOf[WrappedTimestamp]) with SupportsLtGt + + def decimalColumn(columnPath: String): DecimalColumn = { + new DecimalColumn(ColumnPath.fromDotString(columnPath)) + } + + final class DecimalColumn(columnPath: ColumnPath) + extends Column[Decimal](columnPath, classOf[Decimal]) with SupportsLtGt + + final class WrappedDate(val date: Date) extends Comparable[WrappedDate] { + + override def compareTo(other: WrappedDate): Int = { + date.compareTo(other.date) + } + } + + final class WrappedTimestamp(val timestamp: Timestamp) extends Comparable[WrappedTimestamp] { + + override def compareTo(other: WrappedTimestamp): Int = { + timestamp.compareTo(other.timestamp) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 5c6fa78ae3895..74c43e053b03c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -38,10 +38,12 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import parquet.hadoop._ import parquet.hadoop.api.{InitContext, ReadSupport} import parquet.hadoop.metadata.GlobalMetaData +import parquet.hadoop.api.ReadSupport.ReadContext import parquet.hadoop.util.ContextUtil import parquet.io.ParquetDecodingException import parquet.schema.MessageType +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.SQLConf @@ -76,6 +78,10 @@ case class ParquetTableScan( s"$normalOutput + $partOutput != $attributes, ${relation.output}") override def execute(): RDD[Row] = { + import parquet.filter2.compat.FilterCompat.FilterPredicateCompat + import parquet.filter2.compat.FilterCompat.Filter + import parquet.filter2.predicate.FilterPredicate + val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) @@ -106,13 +112,19 @@ case class ParquetTableScan( // "spark.sql.hints.parquetFilterPushdown" to false inside SparkConf. if (columnPruningPred.length > 0 && sc.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { - ParquetFilters.serializeFilterExpressions(columnPruningPred, conf) + + // Set this in configuration of ParquetInputFormat, needed for RowGroupFiltering + val filter: Filter = ParquetFilters.createRecordFilter(columnPruningPred) + if (filter != null){ + val filterPredicate = filter.asInstanceOf[FilterPredicateCompat].getFilterPredicate() + ParquetInputFormat.setFilterPredicate(conf, filterPredicate) + } } // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata conf.set( SQLConf.PARQUET_CACHE_METADATA, - sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "false")) + sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true")) val baseRDD = new org.apache.spark.rdd.NewHadoopRDD( @@ -362,15 +374,17 @@ private[parquet] class FilteringParquetRowInputFormat override def createRecordReader( inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { + + import parquet.filter2.compat.FilterCompat.NoOpFilter + import parquet.filter2.compat.FilterCompat.Filter + val readSupport: ReadSupport[Row] = new RowReadSupport() - val filterExpressions = - ParquetFilters.deserializeFilterExpressions(ContextUtil.getConfiguration(taskAttemptContext)) - if (filterExpressions.length > 0) { - logInfo(s"Pushing down predicates for RecordFilter: ${filterExpressions.mkString(", ")}") + val filter = ParquetInputFormat.getFilter(ContextUtil.getConfiguration(taskAttemptContext)) + if (!filter.isInstanceOf[NoOpFilter]) { new ParquetRecordReader[Row]( readSupport, - ParquetFilters.createRecordFilter(filterExpressions)) + filter) } else { new ParquetRecordReader[Row](readSupport) } @@ -381,7 +395,7 @@ private[parquet] class FilteringParquetRowInputFormat if (footers eq null) { val conf = ContextUtil.getConfiguration(jobContext) - val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) + val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) val statuses = listStatus(jobContext) fileStatuses = statuses.map(file => file.getPath -> file).toMap if (statuses.isEmpty) { @@ -423,10 +437,8 @@ private[parquet] class FilteringParquetRowInputFormat configuration: Configuration, footers: JList[Footer]): JList[ParquetInputSplit] = { - import FilteringParquetRowInputFormat.blockLocationCache - - val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) - + // Use task side strategy by default + val taskSideMetaData = configuration.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue) val minSplitSize: JLong = Math.max(getFormatMinSplitSize(), configuration.getLong("mapred.min.split.size", 0L)) @@ -435,23 +447,67 @@ private[parquet] class FilteringParquetRowInputFormat s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" + s" minSplitSize = $minSplitSize") } - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + + // Uses strict type checking by default val getGlobalMetaData = classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]]) getGlobalMetaData.setAccessible(true) val globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] - // if parquet file is empty, return empty splits. - if (globalMetaData == null) { - return splits - } + if (globalMetaData == null) { + val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + return splits + } + val readContext = getReadSupport(configuration).init( new InitContext(configuration, globalMetaData.getKeyValueMetaData(), globalMetaData.getSchema())) + + if (taskSideMetaData){ + logInfo("Using Task Side Metadata Split Strategy") + return getTaskSideSplits(configuration, + footers, + maxSplitSize, + minSplitSize, + readContext) + } else { + logInfo("Using Client Side Metadata Split Strategy") + return getClientSideSplits(configuration, + footers, + maxSplitSize, + minSplitSize, + readContext) + } + + } + + def getClientSideSplits( + configuration: Configuration, + footers: JList[Footer], + maxSplitSize: JLong, + minSplitSize: JLong, + readContext: ReadContext): JList[ParquetInputSplit] = { + + import FilteringParquetRowInputFormat.blockLocationCache + import parquet.filter2.compat.FilterCompat; + import parquet.filter2.compat.FilterCompat.Filter; + import parquet.filter2.compat.RowGroupFilter; + + val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) + val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + val filter: Filter = ParquetInputFormat.getFilter(configuration) + var rowGroupsDropped: Long = 0 + var totalRowGroups: Long = 0 + + // Ugly hack, stuck with it until PR: + // https://github.com/apache/incubator-parquet-mr/pull/17 + // is resolved val generateSplits = - classOf[ParquetInputFormat[_]].getDeclaredMethods.find(_.getName == "generateSplits").get + Class.forName("parquet.hadoop.ClientSideMetadataSplitStrategy") + .getDeclaredMethods.find(_.getName == "generateSplits").getOrElse( + sys.error(s"Failed to reflectively invoke ClientSideMetadataSplitStrategy.generateSplits")) generateSplits.setAccessible(true) for (footer <- footers) { @@ -460,29 +516,85 @@ private[parquet] class FilteringParquetRowInputFormat val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) val parquetMetaData = footer.getParquetMetadata val blocks = parquetMetaData.getBlocks - var blockLocations: Array[BlockLocation] = null - if (!cacheMetadata) { - blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) - } else { - blockLocations = blockLocationCache.get(status, new Callable[Array[BlockLocation]] { - def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen) - }) - } + totalRowGroups = totalRowGroups + blocks.size + val filteredBlocks = RowGroupFilter.filterRowGroups( + filter, + blocks, + parquetMetaData.getFileMetaData.getSchema) + rowGroupsDropped = rowGroupsDropped + (blocks.size - filteredBlocks.size) + + if (!filteredBlocks.isEmpty){ + var blockLocations: Array[BlockLocation] = null + if (!cacheMetadata) { + blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) + } else { + blockLocations = blockLocationCache.get(status, new Callable[Array[BlockLocation]] { + def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen) + }) + } + splits.addAll( + generateSplits.invoke( + null, + filteredBlocks, + blockLocations, + status, + readContext.getRequestedSchema.toString, + readContext.getReadSupportMetadata, + minSplitSize, + maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) + } + } + + if (rowGroupsDropped > 0 && totalRowGroups > 0){ + val percentDropped = ((rowGroupsDropped/totalRowGroups.toDouble) * 100).toInt + logInfo(s"Dropping $rowGroupsDropped row groups that do not pass filter predicate " + + s"($percentDropped %) !") + } + else { + logInfo("There were no row groups that could be dropped due to filter predicates") + } + splits + + } + + def getTaskSideSplits( + configuration: Configuration, + footers: JList[Footer], + maxSplitSize: JLong, + minSplitSize: JLong, + readContext: ReadContext): JList[ParquetInputSplit] = { + + val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + + // Ugly hack, stuck with it until PR: + // https://github.com/apache/incubator-parquet-mr/pull/17 + // is resolved + val generateSplits = + Class.forName("parquet.hadoop.TaskSideMetadataSplitStrategy") + .getDeclaredMethods.find(_.getName == "generateTaskSideMDSplits").getOrElse( + sys.error( + s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits")) + generateSplits.setAccessible(true) + + for (footer <- footers) { + val file = footer.getFile + val fs = file.getFileSystem(configuration) + val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) + val blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) splits.addAll( generateSplits.invoke( - null, - blocks, - blockLocations, - status, - parquetMetaData.getFileMetaData, - readContext.getRequestedSchema.toString, - readContext.getReadSupportMetadata, - minSplitSize, - maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) + null, + blockLocations, + status, + readContext.getRequestedSchema.toString, + readContext.getReadSupportMetadata, + minSplitSize, + maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) } splits - } + } + } private[parquet] object FilteringParquetRowInputFormat { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index bdf02401b21be..7bc249660053a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -30,6 +30,7 @@ import parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /** * A `parquet.io.api.RecordMaterializer` for Rows. @@ -173,6 +174,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] def writeValue(schema: DataType, value: Any): Unit = { if (value != null) { schema match { + case t: UserDefinedType[_] => writeValue(t.sqlType, value) case t @ ArrayType(_, _) => writeArray( t, value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) @@ -204,6 +206,11 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case d: DecimalType => + if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { + sys.error(s"Unsupported datatype $d, cannot write to consumer") + } + writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision) case _ => sys.error(s"Do not know how to writer $schema to consumer") } } @@ -283,6 +290,23 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { } writer.endGroup() } + + // Scratch array used to write decimals as fixed-length binary + private val scratchBytes = new Array[Byte](8) + + private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = { + val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision) + val unscaledLong = decimal.toUnscaledLong + var i = 0 + var shift = 8 * (numBytes - 1) + while (i < numBytes) { + scratchBytes(i) = (unscaledLong >> shift).toByte + i += 1 + shift -= 8 + } + writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) + } + } // Optimized for non-nested rows @@ -326,6 +350,11 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case DoubleType => writer.addDouble(record.getDouble(index)) case FloatType => writer.addFloat(record.getFloat(index)) case BooleanType => writer.addBoolean(record.getBoolean(index)) + case d: DecimalType => + if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { + sys.error(s"Unsupported datatype $d, cannot write to consumer") + } + writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index 837ea7695dbb3..c0918a40d136f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -92,6 +92,12 @@ private[sql] object ParquetTestData { required int64 mylong; required float myfloat; required double mydouble; + optional boolean myoptboolean; + optional int32 myoptint; + optional binary myoptstring (UTF8); + optional int64 myoptlong; + optional float myoptfloat; + optional double myoptdouble; } """ @@ -255,6 +261,19 @@ private[sql] object ParquetTestData { record.add(3, i.toLong) record.add(4, i.toFloat + 0.5f) record.add(5, i.toDouble + 0.5d) + if (i % 2 == 0) { + if (i % 3 == 0) { + record.add(6, true) + } else { + record.add(6, false) + } + record.add(7, i) + record.add(8, i.toString) + record.add(9, i.toLong) + record.add(10, i.toFloat + 0.5f) + record.add(11, i.toDouble + 0.5d) + } + writer.write(record) } writer.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index e6389cf77a4c9..fa37d1f2ae7e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -29,8 +29,8 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} import parquet.hadoop.util.ContextUtil -import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType} -import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns} +import parquet.schema.{Type => ParquetType, Types => ParquetTypes, PrimitiveType => ParquetPrimitiveType, MessageType} +import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns, DecimalMetadata} import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} import parquet.schema.Type.Repetition @@ -41,17 +41,25 @@ import org.apache.spark.sql.catalyst.types._ // Implicits import scala.collection.JavaConversions._ +/** A class representing Parquet info fields we care about, for passing back to Parquet */ +private[parquet] case class ParquetTypeInfo( + primitiveType: ParquetPrimitiveTypeName, + originalType: Option[ParquetOriginalType] = None, + decimalMetadata: Option[DecimalMetadata] = None, + length: Option[Int] = None) + private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = classOf[PrimitiveType] isAssignableFrom ctype.getClass def toPrimitiveDataType( parquetType: ParquetPrimitiveType, - binayAsString: Boolean): DataType = + binaryAsString: Boolean): DataType = { + val originalType = parquetType.getOriginalType + val decimalInfo = parquetType.getDecimalMetadata parquetType.getPrimitiveTypeName match { case ParquetPrimitiveTypeName.BINARY - if (parquetType.getOriginalType == ParquetOriginalType.UTF8 || - binayAsString) => StringType + if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType case ParquetPrimitiveTypeName.BINARY => BinaryType case ParquetPrimitiveTypeName.BOOLEAN => BooleanType case ParquetPrimitiveTypeName.DOUBLE => DoubleType @@ -61,9 +69,14 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetPrimitiveTypeName.INT96 => // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? sys.error("Potential loss of precision: cannot convert INT96") + case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY + if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) => + // TODO: for now, our reader only supports decimals that fit in a Long + DecimalType(decimalInfo.getPrecision, decimalInfo.getScale) case _ => sys.error( s"Unsupported parquet datatype $parquetType") } + } /** * Converts a given Parquet `Type` into the corresponding @@ -183,23 +196,40 @@ private[parquet] object ParquetTypesConverter extends Logging { * is not primitive. * * @param ctype The type to convert - * @return The name of the corresponding Parquet primitive type + * @return The name of the corresponding Parquet type properties */ - def fromPrimitiveDataType(ctype: DataType): - Option[(ParquetPrimitiveTypeName, Option[ParquetOriginalType])] = ctype match { - case StringType => Some(ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8)) - case BinaryType => Some(ParquetPrimitiveTypeName.BINARY, None) - case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN, None) - case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE, None) - case FloatType => Some(ParquetPrimitiveTypeName.FLOAT, None) - case IntegerType => Some(ParquetPrimitiveTypeName.INT32, None) + def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match { + case StringType => Some(ParquetTypeInfo( + ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))) + case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY)) + case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN)) + case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE)) + case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT)) + case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) // There is no type for Byte or Short so we promote them to INT32. - case ShortType => Some(ParquetPrimitiveTypeName.INT32, None) - case ByteType => Some(ParquetPrimitiveTypeName.INT32, None) - case LongType => Some(ParquetPrimitiveTypeName.INT64, None) + case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) + case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) + case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) + case DecimalType.Fixed(precision, scale) if precision <= 18 => + // TODO: for now, our writer only supports decimals that fit in a Long + Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, + Some(ParquetOriginalType.DECIMAL), + Some(new DecimalMetadata(precision, scale)), + Some(BYTES_FOR_PRECISION(precision)))) case _ => None } + /** + * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision. + */ + private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision => + var length = 1 + while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) { + length += 1 + } + length + } + /** * Converts a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] into * the corresponding Parquet `Type`. @@ -247,12 +277,22 @@ private[parquet] object ParquetTypesConverter extends Logging { } else { if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED } - val primitiveType = fromPrimitiveDataType(ctype) - primitiveType.map { - case (primitiveType, originalType) => - new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull) + val typeInfo = fromPrimitiveDataType(ctype) + typeInfo.map { + case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) => + val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull) + for (len <- length) { + builder.length(len) + } + for (metadata <- decimalMetadata) { + builder.precision(metadata.getPrecision).scale(metadata.getScale) + } + builder.named(name) }.getOrElse { ctype match { + case udt: UserDefinedType[_] => { + fromDataType(udt.sqlType, name, nullable, inArray) + } case ArrayType(elementType, false) => { val parquetElementType = fromDataType( elementType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala new file mode 100644 index 0000000000000..9b8c6a56b94b4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan + +/** + * A Strategy for planning scans over data sources defined using the sources API. + */ +private[sql] object DataSourceStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedFilteredScan)) => + pruneFilterProject( + l, + projectList, + filters, + (a, f) => t.buildScan(a, f)) :: Nil + + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedScan)) => + pruneFilterProject( + l, + projectList, + filters, + (a, _) => t.buildScan(a)) :: Nil + + case l @ LogicalRelation(t: TableScan) => + execution.PhysicalRDD(l.output, t.buildScan()) :: Nil + + case _ => Nil + } + + protected def pruneFilterProject( + relation: LogicalRelation, + projectList: Seq[NamedExpression], + filterPredicates: Seq[Expression], + scanBuilder: (Array[String], Array[Filter]) => RDD[Row]) = { + + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) + val filterCondition = filterPredicates.reduceLeftOption(And) + + val pushedFilters = selectFilters(filterPredicates.map { _ transform { + case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. + }}).toArray + + if (projectList.map(_.toAttribute) == projectList && + projectSet.size == projectList.size && + filterSet.subsetOf(projectSet)) { + // When it is possible to just use column pruning to get the right projection and + // when the columns of this projection are enough to evaluate all filter conditions, + // just do a scan followed by a filter, with no extra project. + val requestedColumns = + projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. + .map(relation.attributeMap) // Match original case of attributes. + .map(_.name) + .toArray + + val scan = + execution.PhysicalRDD( + projectList.map(_.toAttribute), + scanBuilder(requestedColumns, pushedFilters)) + filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) + } else { + val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq + val columnNames = requestedColumns.map(_.name).toArray + + val scan = execution.PhysicalRDD(requestedColumns, scanBuilder(columnNames, pushedFilters)) + execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) + } + } + + protected def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { + case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) + case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v) + + case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v) + case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v) + + case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v) + case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v) + + case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => + GreaterThanOrEqual(a.name, v) + case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => + LessThanOrEqual(a.name, v) + + case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala new file mode 100644 index 0000000000000..82a2cf8402f8f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.sources + +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.plans.logical.{Statistics, LeafNode, LogicalPlan} + +/** + * Used to link a [[BaseRelation]] in to a logical query plan. + */ +private[sql] case class LogicalRelation(relation: BaseRelation) + extends LeafNode + with MultiInstanceRelation { + + override val output = relation.schema.toAttributes + + // Logical Relations are distinct if they have different output for the sake of transformations. + override def equals(other: Any) = other match { + case l @ LogicalRelation(otherRelation) => relation == otherRelation && output == l.output + case _ => false + } + + override def sameResult(otherPlan: LogicalPlan) = otherPlan match { + case LogicalRelation(otherRelation) => relation == otherRelation + case _ => false + } + + @transient override lazy val statistics = Statistics( + // TODO: Allow datasources to provide statistics as well. + sizeInBytes = BigInt(relation.sqlContext.defaultSizeInBytes) + ) + + /** Used to lookup original attribute capitalization */ + val attributeMap = AttributeMap(output.map(o => (o, o))) + + def newInstance() = LogicalRelation(relation).asInstanceOf[this.type] + + override def simpleString = s"Relation[${output.mkString(",")}] $relation" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala new file mode 100644 index 0000000000000..9168ca2fc6fec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.util.Utils + +import scala.language.implicitConversions +import scala.util.parsing.combinator.lexical.StdLexical +import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.combinator.PackratParsers + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.SqlLexical + +/** + * A parser for foreign DDL commands. + */ +private[sql] class DDLParser extends StandardTokenParsers with PackratParsers with Logging { + + def apply(input: String): Option[LogicalPlan] = { + phrase(ddl)(new lexical.Scanner(input)) match { + case Success(r, x) => Some(r) + case x => + logDebug(s"Not recognized as DDL: $x") + None + } + } + + protected case class Keyword(str: String) + + protected implicit def asParser(k: Keyword): Parser[String] = + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + + protected val CREATE = Keyword("CREATE") + protected val TEMPORARY = Keyword("TEMPORARY") + protected val TABLE = Keyword("TABLE") + protected val USING = Keyword("USING") + protected val OPTIONS = Keyword("OPTIONS") + + // Use reflection to find the reserved words defined in this class. + protected val reservedWords = + this.getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new SqlLexical(reservedWords) + + protected lazy val ddl: Parser[LogicalPlan] = createTable + + /** + * CREATE FOREIGN TEMPORARY TABLE avroTable + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro") + */ + protected lazy val createTable: Parser[LogicalPlan] = + CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ { + case tableName ~ provider ~ opts => + CreateTableUsing(tableName, provider, opts) + } + + protected lazy val options: Parser[Map[String, String]] = + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + + protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} + + protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } +} + +private[sql] case class CreateTableUsing( + tableName: String, + provider: String, + options: Map[String, String]) extends RunnableCommand { + + def run(sqlContext: SQLContext) = { + val loader = Utils.getContextOrSparkClassLoader + val clazz: Class[_] = try loader.loadClass(provider) catch { + case cnf: java.lang.ClassNotFoundException => + try loader.loadClass(provider + ".DefaultSource") catch { + case cnf: java.lang.ClassNotFoundException => + sys.error(s"Failed to load class for data source: $provider") + } + } + val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider] + val relation = dataSource.createRelation(sqlContext, options) + + sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName) + Seq.empty + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala new file mode 100644 index 0000000000000..e72a2aeb8f310 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +abstract class Filter + +case class EqualTo(attribute: String, value: Any) extends Filter +case class GreaterThan(attribute: String, value: Any) extends Filter +case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter +case class LessThan(attribute: String, value: Any) extends Filter +case class LessThanOrEqual(attribute: String, value: Any) extends Filter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala new file mode 100644 index 0000000000000..ac3bf9d8e1a21 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -0,0 +1,86 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.apache.spark.sql.sources + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, StructType} +import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} + +/** + * Implemented by objects that produce relations for a specific kind of data source. When + * Spark SQL is given a DDL operation with a USING clause specified, this interface is used to + * pass in the parameters specified by a user. + * + * Users may specify the fully qualified class name of a given data source. When that class is + * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for + * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the + * data source 'org.apache.spark.sql.json.DefaultSource' + * + * A new instance of this class with be instantiated each time a DDL call is made. + */ +@DeveloperApi +trait RelationProvider { + /** Returns a new base relation with the given parameters. */ + def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation +} + +/** + * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must + * be able to produce the schema of their data in the form of a [[StructType]] Concrete + * implementation should inherit from one of the descendant `Scan` classes, which define various + * abstract methods for execution. + * + * BaseRelations must also define a equality function that only returns true when the two + * instances will return the same data. This equality function is used when determining when + * it is safe to substitute cached results for a given relation. + */ +@DeveloperApi +abstract class BaseRelation { + def sqlContext: SQLContext + def schema: StructType +} + +/** + * A BaseRelation that can produce all of its tuples as an RDD of Row objects. + */ +@DeveloperApi +abstract class TableScan extends BaseRelation { + def buildScan(): RDD[Row] +} + +/** + * A BaseRelation that can eliminate unneeded columns before producing an RDD + * containing all of its tuples as Row objects. + */ +@DeveloperApi +abstract class PrunedScan extends BaseRelation { + def buildScan(requiredColumns: Array[String]): RDD[Row] +} + +/** + * A BaseRelation that can eliminate unneeded columns and filter using selected + * predicates before producing an RDD containing all matching tuples as Row objects. + * + * The pushed down filters are currently purely an optimization as they will all be evaluated + * again. This means it is safe to use them with methods that produce false positives such + * as filtering partitions based on a bloom filter. + */ +@DeveloperApi +abstract class PrunedFilteredScan extends BaseRelation { + def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala new file mode 100644 index 0000000000000..8393c510f4f6d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +/** + * A set of APIs for adding data sources to Spark SQL. + */ +package object sources diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala new file mode 100644 index 0000000000000..b9569e96c0312 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.types._ + +/** + * An example class to demonstrate UDT in Scala, Java, and Python. + * @param x x coordinate + * @param y y coordinate + */ +@SQLUserDefinedType(udt = classOf[ExamplePointUDT]) +private[sql] class ExamplePoint(val x: Double, val y: Double) + +/** + * User-defined type for [[ExamplePoint]]. + */ +private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { + + override def sqlType: DataType = ArrayType(DoubleType, false) + + override def pyUDT: String = "pyspark.tests.ExamplePointUDT" + + override def serialize(obj: Any): Seq[Double] = { + obj match { + case p: ExamplePoint => + Seq(p.x, p.y) + } + } + + override def deserialize(datum: Any): ExamplePoint = { + datum match { + case values: Seq[_] => + val xy = values.asInstanceOf[Seq[Double]] + assert(xy.length == 2) + new ExamplePoint(xy(0), xy(1)) + case values: util.ArrayList[_] => + val xy = values.asInstanceOf[util.ArrayList[Double]].asScala + new ExamplePoint(xy(0), xy(1)) + } + } + + override def userClass: Class[ExamplePoint] = classOf[ExamplePoint] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 609f7db562a31..d4258156f18f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -17,11 +17,18 @@ package org.apache.spark.sql.types.util -import org.apache.spark.sql._ -import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField} +import java.text.SimpleDateFormat import scala.collection.JavaConverters._ +import org.apache.spark.sql._ +import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, + MetadataBuilder => JMetaDataBuilder, UDTWrappers, JavaToScalaUDTWrapper} +import org.apache.spark.sql.api.java.{DecimalType => JDecimalType} +import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.types.UserDefinedType + protected[sql] object DataTypeConversions { /** @@ -31,19 +38,24 @@ protected[sql] object DataTypeConversions { JDataType.createStructField( scalaStructField.name, asJavaDataType(scalaStructField.dataType), - scalaStructField.nullable) + scalaStructField.nullable, + (new JMetaDataBuilder).withMetadata(scalaStructField.metadata).build()) } /** * Returns the equivalent DataType in Java for the given DataType in Scala. */ def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match { + case udtType: UserDefinedType[_] => + UDTWrappers.wrapAsJava(udtType) + case StringType => JDataType.StringType case BinaryType => JDataType.BinaryType case BooleanType => JDataType.BooleanType case DateType => JDataType.DateType case TimestampType => JDataType.TimestampType - case DecimalType => JDataType.DecimalType + case DecimalType.Fixed(precision, scale) => new JDecimalType(precision, scale) + case DecimalType.Unlimited => new JDecimalType() case DoubleType => JDataType.DoubleType case FloatType => JDataType.FloatType case ByteType => JDataType.ByteType @@ -68,13 +80,17 @@ protected[sql] object DataTypeConversions { StructField( javaStructField.getName, asScalaDataType(javaStructField.getDataType), - javaStructField.isNullable) + javaStructField.isNullable, + javaStructField.getMetadata) } /** * Returns the equivalent DataType in Scala for the given DataType in Java. */ def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { + case udtType: org.apache.spark.sql.api.java.UserDefinedType[_] => + UDTWrappers.wrapAsScala(udtType) + case stringType: org.apache.spark.sql.api.java.StringType => StringType case binaryType: org.apache.spark.sql.api.java.BinaryType => @@ -86,7 +102,11 @@ protected[sql] object DataTypeConversions { case timestampType: org.apache.spark.sql.api.java.TimestampType => TimestampType case decimalType: org.apache.spark.sql.api.java.DecimalType => - DecimalType + if (decimalType.isFixed) { + DecimalType(decimalType.getPrecision, decimalType.getScale) + } else { + DecimalType.Unlimited + } case doubleType: org.apache.spark.sql.api.java.DoubleType => DoubleType case floatType: org.apache.spark.sql.api.java.FloatType => @@ -111,10 +131,39 @@ protected[sql] object DataTypeConversions { StructType(structType.getFields.map(asScalaStructField)) } + def stringToTime(s: String): java.util.Date = { + if (!s.contains('T')) { + // JDBC escape string + if (s.contains(' ')) { + java.sql.Timestamp.valueOf(s) + } else { + java.sql.Date.valueOf(s) + } + } else if (s.endsWith("Z")) { + // this is zero timezone of ISO8601 + stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") + } else if (s.indexOf("GMT") == -1) { + // timezone with ISO8601 + val inset = "+00.00".length + val s0 = s.substring(0, s.length - inset) + val s1 = s.substring(s.length - inset, s.length) + if (s0.substring(s0.lastIndexOf(':')).contains('.')) { + stringToTime(s0 + "GMT" + s1) + } else { + stringToTime(s0 + ".0GMT" + s1) + } + } else { + // ISO8601 with GMT insert + val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) + ISO8601GMT.parse(s) + } + } + /** Converts Java objects to catalyst rows / types */ - def convertJavaToCatalyst(a: Any): Any = a match { - case d: java.math.BigDecimal => BigDecimal(d) - case other => other + def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type + case (d: java.math.BigDecimal, _) => Decimal(BigDecimal(d)) + case (other, _) => other } /** Converts Java objects to catalyst rows / types */ diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 9435a88009a5f..a04b8060cd658 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -118,7 +118,7 @@ public void applySchemaToJSON() { "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); List fields = new ArrayList(7); - fields.add(DataType.createStructField("bigInteger", DataType.DecimalType, true)); + fields.add(DataType.createStructField("bigInteger", new DecimalType(), true)); fields.add(DataType.createStructField("boolean", DataType.BooleanType, true)); fields.add(DataType.createStructField("double", DataType.DoubleType, true)); fields.add(DataType.createStructField("integer", DataType.IntegerType, true)); diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java index d04396a5f8ec2..8396a29c61c4c 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java @@ -41,7 +41,8 @@ public void createDataTypes() { checkDataType(DataType.BooleanType); checkDataType(DataType.DateType); checkDataType(DataType.TimestampType); - checkDataType(DataType.DecimalType); + checkDataType(new DecimalType()); + checkDataType(new DecimalType(10, 4)); checkDataType(DataType.DoubleType); checkDataType(DataType.FloatType); checkDataType(DataType.ByteType); @@ -59,7 +60,7 @@ public void createDataTypes() { // Simple StructType. List simpleFields = new ArrayList(); - simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); simpleFields.add(DataType.createStructField("d", DataType.BinaryType, false)); @@ -128,7 +129,7 @@ public void illegalArgument() { // StructType try { List simpleFields = new ArrayList(); - simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); simpleFields.add(null); @@ -138,7 +139,7 @@ public void illegalArgument() { } try { List simpleFields = new ArrayList(); - simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); simpleFields.add(DataType.createStructField("a", DataType.BooleanType, true)); simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); DataType.createStructType(simpleFields); diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java new file mode 100644 index 0000000000000..0caa8219a63e9 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; +import java.util.*; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.MyDenseVector; +import org.apache.spark.sql.MyLabeledPoint; + +public class JavaUserDefinedTypeSuite implements Serializable { + private transient JavaSparkContext javaCtx; + private transient JavaSQLContext javaSqlCtx; + + @Before + public void setUp() { + javaCtx = new JavaSparkContext("local", "JavaUserDefinedTypeSuite"); + javaSqlCtx = new JavaSQLContext(javaCtx); + } + + @After + public void tearDown() { + javaCtx.stop(); + javaCtx = null; + javaSqlCtx = null; + } + + @Test + public void useScalaUDT() { + List points = Arrays.asList( + new MyLabeledPoint(1.0, new MyDenseVector(new double[]{0.1, 1.0})), + new MyLabeledPoint(0.0, new MyDenseVector(new double[]{0.2, 2.0}))); + JavaRDD pointsRDD = javaCtx.parallelize(points); + + JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(pointsRDD, MyLabeledPoint.class); + schemaRDD.registerTempTable("points"); + + List actualLabelRows = javaSqlCtx.sql("SELECT label FROM points").collect(); + List actualLabels = new LinkedList(); + for (Row r : actualLabelRows) { + actualLabels.add(r.getDouble(0)); + } + for (MyLabeledPoint lp : points) { + Assert.assertTrue(actualLabels.contains(lp.label())); + } + + List actualFeatureRows = javaSqlCtx.sql("SELECT features FROM points").collect(); + List actualFeatures = new LinkedList(); + for (Row r : actualFeatureRows) { + actualFeatures.add((MyDenseVector)r.get(0)); + } + for (MyLabeledPoint lp : points) { + Assert.assertTrue(actualFeatures.contains(lp.features())); + } + + List actual = javaSqlCtx.sql("SELECT label, features FROM points").collect(); + List actualPoints = + new LinkedList(); + for (Row r : actual) { + actualPoints.add(new MyLabeledPoint(r.getDouble(0), (MyDenseVector)r.get(1))); + } + for (MyLabeledPoint lp : points) { + Assert.assertTrue(actualPoints.contains(lp)); + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1a5d87d5240e9..765fa82776341 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -27,18 +27,6 @@ case class BigData(s: String) class CachedTableSuite extends QueryTest { TestData // Load test tables. - def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached - } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } - def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan executedPlan.collect { @@ -243,4 +231,24 @@ class CachedTableSuite extends QueryTest { assert(cached.statistics.sizeInBytes === actualSizeInBytes) } } + + test("Drops temporary table") { + testData.select('key).registerTempTable("t1") + table("t1") + dropTempTable("t1") + assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) + } + + test("Drops cached temporary table") { + testData.select('key).registerTempTable("t1") + testData.select('key).registerTempTable("t2") + cacheTable("t1") + + assert(isCached("t1")) + assert(isCached("t2")) + + dropTempTable("t1") + assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) + assert(!isCached("t2")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index 100ecb45e9e88..e9740d913cf57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.types.DataType - class DataTypeSuite extends FunSuite { test("construct an ArrayType") { @@ -71,7 +69,7 @@ class DataTypeSuite extends FunSuite { checkDataTypeJsonRepr(LongType) checkDataTypeJsonRepr(FloatType) checkDataTypeJsonRepr(DoubleType) - checkDataTypeJsonRepr(DecimalType) + checkDataTypeJsonRepr(DecimalType.Unlimited) checkDataTypeJsonRepr(TimestampType) checkDataTypeJsonRepr(StringType) checkDataTypeJsonRepr(BinaryType) @@ -79,8 +77,12 @@ class DataTypeSuite extends FunSuite { checkDataTypeJsonRepr(ArrayType(StringType, false)) checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + val metadata = new MetadataBuilder() + .putString("name", "age") + .build() checkDataTypeJsonRepr( StructType(Seq( StructField("a", IntegerType, nullable = true), - StructField("b", ArrayType(DoubleType), nullable = false)))) + StructField("b", ArrayType(DoubleType), nullable = false), + StructField("c", DoubleType, nullable = false, metadata)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 45e58afe9d9a2..e70ad891eea36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.test._ /* Implicits */ -import TestSQLContext._ +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.test.TestSQLContext._ class DslQuerySuite extends QueryTest { - import TestData._ + import org.apache.spark.sql.TestData._ test("table scan") { checkAnswer( @@ -216,4 +215,14 @@ class DslQuerySuite extends QueryTest { (4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) } + + test("udf") { + val foo = (a: Int, b: String) => a.toString + b + + checkAnswer( + // SELECT *, foo(key, value) FROM testData + testData.select(Star(None), foo.call('key, 'value)).limit(3), + (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 042f61f5a4113..3d9f0cbf80fe7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.columnar.InMemoryRelation class QueryTest extends PlanTest { + /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer @@ -78,11 +80,31 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${convertedAnswer.size} ==" +: - prepareAnswer(convertedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + s"== Correct Answer - ${convertedAnswer.size} ==" +: + prepareAnswer(convertedAnswer).map(_.toString), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} """.stripMargin) } } + + def sqlTest(sqlString: String, expectedAnswer: Any)(implicit sqlContext: SQLContext): Unit = { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } + + /** Asserts that a given SchemaRDD will be executed using the given number of cached results. */ + def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1034c2d05f8cf..8a80724c08c7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql +import java.util.TimeZone + +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.joins.BroadcastHashJoin -import org.apache.spark.sql.test._ -import org.scalatest.BeforeAndAfterAll -import java.util.TimeZone /* Implicits */ -import TestSQLContext._ -import TestData._ +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext._ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. @@ -281,14 +281,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { 3) } - // No support for primitive nulls yet. - ignore("null count") { + test("null count") { checkAnswer( - sql("SELECT a, COUNT(b) FROM testData3"), - Seq((1,0), (2, 1))) + sql("SELECT a, COUNT(b) FROM testData3 GROUP BY a"), + Seq((1, 0), (2, 1))) checkAnswer( - testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)), + sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), (2, 1, 2, 2, 1) :: Nil) } @@ -697,6 +696,30 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ("true", "false") :: Nil) } + test("metadata is propagated correctly") { + val person = sql("SELECT * FROM person") + val schema = person.schema + val docKey = "doc" + val docValue = "first name" + val metadata = new MetadataBuilder() + .putString(docKey, docValue) + .build() + val schemaWithMeta = new StructType(Seq( + schema("id"), schema("name").copy(metadata = metadata), schema("age"))) + val personWithMeta = applySchema(person, schemaWithMeta) + def validateMetadata(rdd: SchemaRDD): Unit = { + assert(rdd.schema("name").metadata.getString(docKey) == docValue) + } + personWithMeta.registerTempTable("personWithMeta") + validateMetadata(personWithMeta.select('name)) + validateMetadata(personWithMeta.select("name".attr)) + validateMetadata(personWithMeta.select('id, 'name)) + validateMetadata(sql("SELECT * FROM personWithMeta")) + validateMetadata(sql("SELECT id, name FROM personWithMeta")) + validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) + validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) + } + test("SPARK-3371 Renaming a function expression with group by gives error") { registerFunction("len", (s: String) => s.length) checkAnswer( @@ -899,4 +922,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3814 Support Bitwise ~ operator") { checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), -2) } + + test("SPARK-4120 Join of multiple tables does not work in SparkSQL") { + checkAnswer( + sql( + """SELECT a.key, b.key, c.key + |FROM testData a,testData b,testData c + |where a.key = b.key and a.key = c.key + """.stripMargin), + (1 to 100).map(i => Seq(i, i, i))) + } + + test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { + checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), + (11 to 100).map(i => Seq(i))) + } + + test("SPARK-4207 Query which has syntax like 'not like' is not working in Spark SQL") { + checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"), + (1 to 99).map(i => Seq(i))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index bfa9ea416266d..cf3a59e545905 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ @@ -81,7 +82,9 @@ class ScalaReflectionRelationSuite extends FunSuite { val rdd = sparkContext.parallelize(data :: Nil) rdd.registerTempTable("reflectData") - assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq) + assert(sql("SELECT * FROM reflectData").collect().head === + Seq("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, + BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))) } test("query case class RDD with nulls") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index c4dd3e860f5fd..92b49e8155900 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -64,11 +64,12 @@ object TestData { BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD binaryData.registerTempTable("binaryData") - // TODO: There is no way to express null primitives as case classes currently... + case class TestData3(a: Int, b: Option[Int]) val testData3 = - logical.LocalRelation('a.int, 'b.int).loadData( - (1, null) :: - (2, 2) :: Nil) + TestSQLContext.sparkContext.parallelize( + TestData3(1, None) :: + TestData3(2, Some(2)) :: Nil).toSchemaRDD + testData3.registerTempTable("testData3") val emptyTableData = logical.LocalRelation('a.int, 'b.int) @@ -166,4 +167,23 @@ object TestData { // An RDD with 4 elements and 8 partitions val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) withEmptyParts.registerTempTable("withEmptyParts") + + case class Person(id: Int, name: String, age: Int) + case class Salary(personId: Int, salary: Double) + val person = TestSQLContext.sparkContext.parallelize( + Person(0, "mike", 30) :: + Person(1, "jim", 20) :: Nil) + person.registerTempTable("person") + val salary = TestSQLContext.sparkContext.parallelize( + Salary(0, 2000.0) :: + Salary(1, 1000.0) :: Nil) + salary.registerTempTable("salary") + + case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean) + val complexData = + TestSQLContext.sparkContext.parallelize( + ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true) + :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false) + :: Nil).toSchemaRDD + complexData.registerTempTable("complexData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala new file mode 100644 index 0000000000000..1806a1dd82023 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.beans.{BeanInfo, BeanProperty} + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.types.UserDefinedType +import org.apache.spark.sql.test.TestSQLContext._ + +@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) +private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { + override def equals(other: Any): Boolean = other match { + case v: MyDenseVector => + java.util.Arrays.equals(this.data, v.data) + case _ => false + } +} + +@BeanInfo +private[sql] case class MyLabeledPoint( + @BeanProperty label: Double, + @BeanProperty features: MyDenseVector) + +private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { + + override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) + + override def serialize(obj: Any): Seq[Double] = { + obj match { + case features: MyDenseVector => + features.data.toSeq + } + } + + override def deserialize(datum: Any): MyDenseVector = { + datum match { + case data: Seq[_] => + new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray) + } + } + + override def userClass = classOf[MyDenseVector] +} + +class UserDefinedTypeSuite extends QueryTest { + val points = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) + val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) + + + test("register user type: MyDenseVector for MyLabeledPoint") { + val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } + val labelsArrays: Array[Double] = labels.collect() + assert(labelsArrays.size === 2) + assert(labelsArrays.contains(1.0)) + assert(labelsArrays.contains(0.0)) + + val features: RDD[MyDenseVector] = + pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } + val featuresArrays: Array[MyDenseVector] = features.collect() + assert(featuresArrays.size === 2) + assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) + assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) + } + + test("UDTs and UDFs") { + registerFunction("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + pointsRDD.registerTempTable("points") + checkAnswer( + sql("SELECT testType(features) from points"), + Seq(Row(true), Row(true))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index d83f3e23a9468..c9012c9e47cff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.beans.BeanProperty import org.scalatest.FunSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala index 8415af41be3af..62fe59dd345d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.api.java -import org.apache.spark.sql.types.util.DataTypeConversions import org.scalatest.FunSuite -import org.apache.spark.sql.{DataType => SDataType, StructField => SStructField} -import org.apache.spark.sql.{StructType => SStructType} -import DataTypeConversions._ +import org.apache.spark.sql.{DataType => SDataType, StructField => SStructField, StructType => SStructType} +import org.apache.spark.sql.types.util.DataTypeConversions._ class ScalaSideDataTypeConversionSuite extends FunSuite { @@ -40,7 +38,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite { checkDataType(org.apache.spark.sql.BooleanType) checkDataType(org.apache.spark.sql.DateType) checkDataType(org.apache.spark.sql.TimestampType) - checkDataType(org.apache.spark.sql.DecimalType) + checkDataType(org.apache.spark.sql.DecimalType.Unlimited) checkDataType(org.apache.spark.sql.DoubleType) checkDataType(org.apache.spark.sql.FloatType) checkDataType(org.apache.spark.sql.ByteType) @@ -60,18 +58,22 @@ class ScalaSideDataTypeConversionSuite extends FunSuite { // Simple StructType. val simpleScalaStructType = SStructType( - SStructField("a", org.apache.spark.sql.DecimalType, false) :: + SStructField("a", org.apache.spark.sql.DecimalType.Unlimited, false) :: SStructField("b", org.apache.spark.sql.BooleanType, true) :: SStructField("c", org.apache.spark.sql.LongType, true) :: SStructField("d", org.apache.spark.sql.BinaryType, false) :: Nil) checkDataType(simpleScalaStructType) // Complex StructType. + val metadata = new MetadataBuilder() + .putString("name", "age") + .build() val complexScalaStructType = SStructType( SStructField("simpleArray", simpleScalaArrayType, true) :: SStructField("simpleMap", simpleScalaMapType, true) :: SStructField("simpleStruct", simpleScalaStructType, true) :: - SStructField("boolean", org.apache.spark.sql.BooleanType, false) :: Nil) + SStructField("boolean", org.apache.spark.sql.BooleanType, false) :: + SStructField("withMeta", org.apache.spark.sql.DoubleType, false, metadata) :: Nil) checkDataType(complexScalaStructType) // Complex ArrayType. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 9775dd26b7773..15903d07df29a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.columnar +import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY class InMemoryColumnarQuerySuite extends QueryTest { - import org.apache.spark.sql.TestData._ - import org.apache.spark.sql.test.TestSQLContext._ + // Make sure the tables are loaded. + TestData test("simple columnar query") { - val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan + val plan = executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().toSeq) @@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("projection") { - val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().map { @@ -51,7 +52,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan + val plan = executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().toSeq) @@ -63,7 +64,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq) - TestSQLContext.cacheTable("repeatedData") + cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -75,7 +76,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq) - TestSQLContext.cacheTable("nullableRepeatedData") + cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -87,7 +88,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) - TestSQLContext.cacheTable("timestamps") + cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -99,10 +100,17 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq) - TestSQLContext.cacheTable("withEmptyParts") + cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq) } + + test("SPARK-4182 Caching complex types") { + complexData.cache().count() + // Shouldn't throw + complexData.count() + complexData.unpersist() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index ce6184f5d8c9d..f8ca2c773d9ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.json import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ @@ -44,25 +44,35 @@ class JsonSuite extends QueryTest { checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType)) checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType)) - checkTypePromotion(BigDecimal(intNumber), enforceCorrectType(intNumber, DecimalType)) + checkTypePromotion( + Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited)) val longNumber: Long = 9223372036854775807L checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType)) checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType)) - checkTypePromotion(BigDecimal(longNumber), enforceCorrectType(longNumber, DecimalType)) + checkTypePromotion( + Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited)) val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) - checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType)) - + checkTypePromotion( + Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited)) + checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType)) - checkTypePromotion(new Timestamp(intNumber.toLong), + checkTypePromotion(new Timestamp(intNumber.toLong), enforceCorrectType(intNumber.toLong, TimestampType)) val strTime = "2014-09-30 12:34:56" checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType)) val strDate = "2014-10-15" checkTypePromotion(Date.valueOf(strDate), enforceCorrectType(strDate, DateType)) + + val ISO8601Time1 = "1970-01-01T01:00:01.0Z" + checkTypePromotion(new Timestamp(3601000), enforceCorrectType(ISO8601Time1, TimestampType)) + checkTypePromotion(new Date(3601000), enforceCorrectType(ISO8601Time1, DateType)) + val ISO8601Time2 = "1970-01-01T02:00:01-01:00" + checkTypePromotion(new Timestamp(10801000), enforceCorrectType(ISO8601Time2, TimestampType)) + checkTypePromotion(new Date(10801000), enforceCorrectType(ISO8601Time2, DateType)) } test("Get compatible type") { @@ -80,7 +90,7 @@ class JsonSuite extends QueryTest { checkDataType(NullType, IntegerType, IntegerType) checkDataType(NullType, LongType, LongType) checkDataType(NullType, DoubleType, DoubleType) - checkDataType(NullType, DecimalType, DecimalType) + checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited) checkDataType(NullType, StringType, StringType) checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType)) checkDataType(NullType, StructType(Nil), StructType(Nil)) @@ -91,7 +101,7 @@ class JsonSuite extends QueryTest { checkDataType(BooleanType, IntegerType, StringType) checkDataType(BooleanType, LongType, StringType) checkDataType(BooleanType, DoubleType, StringType) - checkDataType(BooleanType, DecimalType, StringType) + checkDataType(BooleanType, DecimalType.Unlimited, StringType) checkDataType(BooleanType, StringType, StringType) checkDataType(BooleanType, ArrayType(IntegerType), StringType) checkDataType(BooleanType, StructType(Nil), StringType) @@ -100,7 +110,7 @@ class JsonSuite extends QueryTest { checkDataType(IntegerType, IntegerType, IntegerType) checkDataType(IntegerType, LongType, LongType) checkDataType(IntegerType, DoubleType, DoubleType) - checkDataType(IntegerType, DecimalType, DecimalType) + checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited) checkDataType(IntegerType, StringType, StringType) checkDataType(IntegerType, ArrayType(IntegerType), StringType) checkDataType(IntegerType, StructType(Nil), StringType) @@ -108,23 +118,23 @@ class JsonSuite extends QueryTest { // LongType checkDataType(LongType, LongType, LongType) checkDataType(LongType, DoubleType, DoubleType) - checkDataType(LongType, DecimalType, DecimalType) + checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited) checkDataType(LongType, StringType, StringType) checkDataType(LongType, ArrayType(IntegerType), StringType) checkDataType(LongType, StructType(Nil), StringType) // DoubleType checkDataType(DoubleType, DoubleType, DoubleType) - checkDataType(DoubleType, DecimalType, DecimalType) + checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited) checkDataType(DoubleType, StringType, StringType) checkDataType(DoubleType, ArrayType(IntegerType), StringType) checkDataType(DoubleType, StructType(Nil), StringType) // DoubleType - checkDataType(DecimalType, DecimalType, DecimalType) - checkDataType(DecimalType, StringType, StringType) - checkDataType(DecimalType, ArrayType(IntegerType), StringType) - checkDataType(DecimalType, StructType(Nil), StringType) + checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(DecimalType.Unlimited, StringType, StringType) + checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType) + checkDataType(DecimalType.Unlimited, StructType(Nil), StringType) // StringType checkDataType(StringType, StringType, StringType) @@ -178,7 +188,7 @@ class JsonSuite extends QueryTest { checkDataType( StructType( StructField("f1", IntegerType, true) :: Nil), - DecimalType, + DecimalType.Unlimited, StringType) } @@ -186,7 +196,7 @@ class JsonSuite extends QueryTest { val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) val expectedSchema = StructType( - StructField("bigInteger", DecimalType, true) :: + StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", IntegerType, true) :: @@ -216,7 +226,7 @@ class JsonSuite extends QueryTest { val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) :: StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) :: - StructField("arrayOfBigInteger", ArrayType(DecimalType, false), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, false), true) :: StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) :: StructField("arrayOfDouble", ArrayType(DoubleType, false), true) :: StructField("arrayOfInteger", ArrayType(IntegerType, false), true) :: @@ -229,8 +239,8 @@ class JsonSuite extends QueryTest { StructField("field2", StringType, true) :: StructField("field3", StringType, true) :: Nil), false), true) :: StructField("struct", StructType( - StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType, true) :: Nil), true) :: + StructField("field1", BooleanType, true) :: + StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: StructField("structWithArrayFields", StructType( StructField("field1", ArrayType(IntegerType, false), true) :: StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil) @@ -288,8 +298,8 @@ class JsonSuite extends QueryTest { // Access a struct and fields inside of it. checkAnswer( sql("select struct, struct.field1, struct.field2 from jsonTable"), - ( - Seq(true, BigDecimal("92233720368547758070")), + Row( + Row(true, BigDecimal("92233720368547758070")), true, BigDecimal("92233720368547758070")) :: Nil ) @@ -331,7 +341,7 @@ class JsonSuite extends QueryTest { val expectedSchema = StructType( StructField("num_bool", StringType, true) :: StructField("num_num_1", LongType, true) :: - StructField("num_num_2", DecimalType, true) :: + StructField("num_num_2", DecimalType.Unlimited, true) :: StructField("num_num_3", DoubleType, true) :: StructField("num_str", StringType, true) :: StructField("str_bool", StringType, true) :: Nil) @@ -479,7 +489,8 @@ class JsonSuite extends QueryTest { val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: StructField("array2", ArrayType(StructType( - StructField("field", LongType, true) :: Nil), false), true) :: Nil) + StructField("field", LongType, true) :: Nil), false), true) :: + StructField("array3", ArrayType(StringType, false), true) :: Nil) assert(expectedSchema === jsonSchemaRDD.schema) @@ -488,12 +499,14 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), Seq(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", - """{"field":str}"""), Seq(Seq(214748364700L), Seq(1))) :: Nil + """{"field":"str"}"""), Seq(Seq(214748364700L), Seq(1)), null) :: + Seq(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) :: + Seq(null, null, Seq("1", "2", "3")) :: Nil ) // Treat an element as a number. checkAnswer( - sql("select array1[0] + 1 from jsonTable"), + sql("select array1[0] + 1 from jsonTable where array1 is not null"), 2 ) } @@ -521,7 +534,7 @@ class JsonSuite extends QueryTest { val jsonSchemaRDD = jsonFile(path) val expectedSchema = StructType( - StructField("bigInteger", DecimalType, true) :: + StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", IntegerType, true) :: @@ -545,13 +558,39 @@ class JsonSuite extends QueryTest { ) } + test("Loading a JSON dataset from a text file with SQL") { + val file = getTempFilePath("json") + val path = file.toString + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + sql( + s""" + |CREATE TEMPORARY TABLE jsonTableSQL + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + + checkAnswer( + sql("select * from jsonTableSQL"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + } + test("Applying schemas") { val file = getTempFilePath("json") val path = file.toString primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val schema = StructType( - StructField("bigInteger", DecimalType, true) :: + StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", IntegerType, true) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index c204162dd2fc1..e5773a55875bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -57,7 +57,9 @@ object TestJsonData { val arrayElementTypeConflict = TestSQLContext.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], - "array2": [{"field":214748364700}, {"field":1}]}""" :: Nil) + "array2": [{"field":214748364700}, {"field":1}]}""" :: + """{"array3": [{"field":"str"}, {"field":1}]}""" :: + """{"array3": [1, 2, 3]}""" :: Nil) val missingFields = TestSQLContext.sparkContext.parallelize( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 25e41ecf28e2e..3cccafe92d4f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -77,6 +77,8 @@ case class AllDataTypesWithNonPrimitiveType( case class BinaryData(binaryData: Array[Byte]) +case class NumericData(i: Int, d: Double) + class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { TestData // Load test data tables. @@ -560,6 +562,103 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(stringResult.size === 1) assert(stringResult(0).getString(2) == "100", "stringvalue incorrect") assert(stringResult(0).getInt(1) === 100) + + val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40") + assert( + query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val optResult = query7.collect() + assert(optResult.size === 20) + for(i <- 0 until 20) { + if (optResult(i)(7) != i * 2) { + fail(s"optional Int value in result row $i should be ${2*4*i}") + } + } + for(myval <- Seq("myoptint", "myoptlong", "myoptdouble", "myoptfloat")) { + val query8 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") + assert( + query8.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result8 = query8.collect() + assert(result8.size === 25) + assert(result8(0)(7) === 100) + assert(result8(24)(7) === 148) + val query9 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") + assert( + query9.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result9 = query9.collect() + assert(result9.size === 25) + if (myval == "myoptint" || myval == "myoptlong") { + assert(result9(0)(7) === 152) + assert(result9(24)(7) === 200) + } else { + assert(result9(0)(7) === 150) + assert(result9(24)(7) === 198) + } + } + val query10 = sql("SELECT * FROM testfiltersource WHERE myoptstring = \"100\"") + assert( + query10.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result10 = query10.collect() + assert(result10.size === 1) + assert(result10(0).getString(8) == "100", "stringvalue incorrect") + assert(result10(0).getInt(7) === 100) + val query11 = sql(s"SELECT * FROM testfiltersource WHERE myoptboolean = true AND myoptint < 40") + assert( + query11.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result11 = query11.collect() + assert(result11.size === 7) + for(i <- 0 until 6) { + if (!result11(i).getBoolean(6)) { + fail(s"optional Boolean value in result row $i not true") + } + if (result11(i).getInt(7) != i * 6) { + fail(s"optional Int value in result row $i should be ${6*i}") + } + } + + val query12 = sql("SELECT * FROM testfiltersource WHERE mystring >= \"50\"") + assert( + query12.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result12 = query12.collect() + assert(result12.size === 54) + assert(result12(0).getString(2) == "6") + assert(result12(4).getString(2) == "50") + assert(result12(53).getString(2) == "99") + + val query13 = sql("SELECT * FROM testfiltersource WHERE mystring > \"50\"") + assert( + query13.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result13 = query13.collect() + assert(result13.size === 53) + assert(result13(0).getString(2) == "6") + assert(result13(4).getString(2) == "51") + assert(result13(52).getString(2) == "99") + + val query14 = sql("SELECT * FROM testfiltersource WHERE mystring <= \"50\"") + assert( + query14.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result14 = query14.collect() + assert(result14.size === 148) + assert(result14(0).getString(2) == "0") + assert(result14(46).getString(2) == "50") + assert(result14(147).getString(2) == "200") + + val query15 = sql("SELECT * FROM testfiltersource WHERE mystring < \"50\"") + assert( + query15.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result15 = query15.collect() + assert(result15.size === 147) + assert(result15(0).getString(2) == "0") + assert(result15(46).getString(2) == "100") + assert(result15(146).getString(2) == "200") } test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { @@ -812,4 +911,35 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(a.dataType === b.dataType) } } + + test("read/write fixed-length decimals") { + for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + val data = sparkContext.parallelize(0 to 1000) + .map(i => NumericData(i, i / 100.0)) + .select('i, 'd cast DecimalType(precision, scale)) + data.saveAsParquetFile(tempDir) + checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) + } + + // Decimals with precision above 18 are not yet supported + intercept[RuntimeException] { + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + val data = sparkContext.parallelize(0 to 1000) + .map(i => NumericData(i, i / 100.0)) + .select('i, 'd cast DecimalType(19, 10)) + data.saveAsParquetFile(tempDir) + checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) + } + + // Unlimited-length decimals are not yet supported + intercept[RuntimeException] { + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + val data = sparkContext.parallelize(0 to 1000) + .map(i => NumericData(i, i / 100.0)) + .select('i, 'd cast DecimalType.Unlimited) + data.saveAsParquetFile(tempDir) + checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala new file mode 100644 index 0000000000000..9626252e742e5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -0,0 +1,34 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Analyzer +import org.apache.spark.sql.test.TestSQLContext +import org.scalatest.BeforeAndAfter + +abstract class DataSourceTest extends QueryTest with BeforeAndAfter { + // Case sensitivity is not configurable yet, but we want to test some edge cases. + // TODO: Remove when it is configurable + implicit val caseInsensisitiveContext = new SQLContext(TestSQLContext.sparkContext) { + @transient + override protected[sql] lazy val analyzer: Analyzer = + new Analyzer(catalog, functionRegistry, caseSensitive = false) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala new file mode 100644 index 0000000000000..8b2f1591d5bf3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -0,0 +1,176 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import scala.language.existentials + +import org.apache.spark.sql._ + +class FilteredScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleFilteredScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends PrunedFilteredScan { + + override def schema = + StructType( + StructField("a", IntegerType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: Nil) + + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { + val rowBuilders = requiredColumns.map { + case "a" => (i: Int) => Seq(i) + case "b" => (i: Int) => Seq(i * 2) + } + + FiltersPushed.list = filters + + val filterFunctions = filters.collect { + case EqualTo("a", v) => (a: Int) => a == v + case LessThan("a", v: Int) => (a: Int) => a < v + case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v + case GreaterThan("a", v: Int) => (a: Int) => a > v + case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v + } + + def eval(a: Int) = !filterFunctions.map(_(a)).contains(false) + + sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => + Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) + } +} + +// A hack for better error messages when filter pushdown fails. +object FiltersPushed { + var list: Seq[Filter] = Nil +} + +class FilteredScanSuite extends DataSourceTest { + + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenFiltered + |USING org.apache.spark.sql.sources.FilteredScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTenFiltered", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT b, a FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2, i)).toSeq) + + sqlTest( + "SELECT a FROM oneToTenFiltered", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a * 2 FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT A AS b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT x.b, y.a FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b", + (1 to 5).map(i => Row(i * 4, i)).toSeq) + + sqlTest( + "SELECT x.a, y.b FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b", + (2 to 10 by 2).map(i => Row(i, i)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE a = 1", + Seq(1).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE A = 1", + Seq(1).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE b = 2", + Seq(1).map(i => Row(i, i * 2)).toSeq) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) + + def testPushDown(sqlString: String, expectedCount: Int): Unit = { + test(s"PushDown Returns $expectedCount: $sqlString") { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawCount = rawPlan.execute().count() + + if (rawCount != expectedCount) { + fail( + s"Wrong # of results for pushed filter. Got $rawCount, Expected $expectedCount\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + } + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala new file mode 100644 index 0000000000000..fee2e22611cdc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -0,0 +1,137 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ + +class PrunedScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends PrunedScan { + + override def schema = + StructType( + StructField("a", IntegerType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: Nil) + + override def buildScan(requiredColumns: Array[String]) = { + val rowBuilders = requiredColumns.map { + case "a" => (i: Int) => Seq(i) + case "b" => (i: Int) => Seq(i * 2) + } + + sqlContext.sparkContext.parallelize(from to to).map(i => + Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) + } +} + +class PrunedScanSuite extends DataSourceTest { + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenPruned + |USING org.apache.spark.sql.sources.PrunedScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT a, b FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT b, a FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2, i)).toSeq) + + sqlTest( + "SELECT a FROM oneToTenPruned", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT a, a FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i)).toSeq) + + sqlTest( + "SELECT b FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a * 2 FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT A AS b FROM oneToTenPruned", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT x.b, y.a FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b", + (1 to 5).map(i => Row(i * 4, i)).toSeq) + + sqlTest( + "SELECT x.a, y.b FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b", + (2 to 10 by 2).map(i => Row(i, i)).toSeq) + + testPruning("SELECT * FROM oneToTenPruned", "a", "b") + testPruning("SELECT a, b FROM oneToTenPruned", "a", "b") + testPruning("SELECT b, a FROM oneToTenPruned", "b", "a") + testPruning("SELECT b, b FROM oneToTenPruned", "b") + testPruning("SELECT a FROM oneToTenPruned", "a") + testPruning("SELECT b FROM oneToTenPruned", "b") + + def testPruning(sqlString: String, expectedColumns: String*): Unit = { + test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawColumns = rawPlan.output.map(_.name) + val rawOutput = rawPlan.execute().first() + + if (rawColumns != expectedColumns) { + fail( + s"Wrong column names. Got $rawColumns, Expected $expectedColumns\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + + if (rawOutput.size != expectedColumns.size) { + fail(s"Wrong output row. Got $rawOutput\n$queryExecution") + } + } + } + +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala new file mode 100644 index 0000000000000..b254b0620c779 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -0,0 +1,125 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ + +class DefaultSource extends SimpleScanSource + +class SimpleScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends TableScan { + + override def schema = + StructType(StructField("i", IntegerType, nullable = false) :: Nil) + + override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) +} + +class TableScanSuite extends DataSourceTest { + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTen + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTen", + (1 to 10).map(Row(_)).toSeq) + + sqlTest( + "SELECT i FROM oneToTen", + (1 to 10).map(Row(_)).toSeq) + + sqlTest( + "SELECT i FROM oneToTen WHERE i < 5", + (1 to 4).map(Row(_)).toSeq) + + sqlTest( + "SELECT i * 2 FROM oneToTen", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1", + (2 to 10).map(i => Row(i, i - 1)).toSeq) + + + test("Caching") { + // Cached Query Execution + cacheTable("oneToTen") + assertCached(sql("SELECT * FROM oneToTen")) + checkAnswer( + sql("SELECT * FROM oneToTen"), + (1 to 10).map(Row(_)).toSeq) + + assertCached(sql("SELECT i FROM oneToTen")) + checkAnswer( + sql("SELECT i FROM oneToTen"), + (1 to 10).map(Row(_)).toSeq) + + assertCached(sql("SELECT i FROM oneToTen WHERE i < 5")) + checkAnswer( + sql("SELECT i FROM oneToTen WHERE i < 5"), + (1 to 4).map(Row(_)).toSeq) + + assertCached(sql("SELECT i * 2 FROM oneToTen")) + checkAnswer( + sql("SELECT i * 2 FROM oneToTen"), + (1 to 10).map(i => Row(i * 2)).toSeq) + + assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) + checkAnswer( + sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), + (2 to 10).map(i => Row(i, i - 1)).toSeq) + + // Verify uncaching + uncacheTable("oneToTen") + assertCached(sql("SELECT * FROM oneToTen"), 0) + } + + test("defaultSource") { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenDef + |USING org.apache.spark.sql.sources + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM oneToTenDef"), + (1 to 10).map(Row(_)).toSeq) + } +} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 124fc107cb8aa..8db3010624100 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -70,6 +70,24 @@ org.scalatest scalatest-maven-plugin + + org.codehaus.mojo + build-helper-maven-plugin + + + add-default-sources + generate-sources + + add-source + + + + v${hive.version.short}/src/main/scala + + + + + org.apache.maven.plugins maven-deploy-plugin diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala similarity index 84% rename from sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala rename to sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala index a5c457c677564..6ed8fd2768f95 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive.thriftserver import scala.collection.JavaConversions._ -import java.util.{ArrayList => JArrayList} - import org.apache.commons.lang.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver @@ -29,11 +27,11 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext) - extends Driver with Logging { +private[hive] abstract class AbstractSparkSQLDriver( + val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver with Logging { - private var tableSchema: Schema = _ - private var hiveResponse: Seq[String] = _ + private[hive] var tableSchema: Schema = _ + private[hive] var hiveResponse: Seq[String] = _ override def init(): Unit = { } @@ -74,16 +72,6 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo override def getSchema: Schema = tableSchema - override def getResults(res: JArrayList[String]): Boolean = { - if (hiveResponse == null) { - false - } else { - res.addAll(hiveResponse) - hiveResponse = null - true - } - } - override def destroy() { super.destroy() hiveResponse = null diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 3d468d804622c..bd4e99492b395 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ - import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} @@ -51,24 +48,12 @@ object HiveThriftServer2 extends Logging { def main(args: Array[String]) { val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") - if (!optionsProcessor.process(args)) { System.exit(-1) } - val ss = new SessionState(new HiveConf(classOf[SessionState])) - - // Set all properties specified via command line. - val hiveConf: HiveConf = ss.getConf - hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => - logDebug(s"HiveConf var: $k=$v") - } - - SessionState.start(ss) - logInfo("Starting SparkContext") SparkSQLEnv.init() - SessionState.start(ss) Runtime.getRuntime.addShutdownHook( new Thread() { @@ -80,7 +65,7 @@ object HiveThriftServer2 extends Logging { try { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) - server.init(hiveConf) + server.init(SparkSQLEnv.hiveContext.hiveconf) server.start() logInfo("HiveThriftServer2 started") } catch { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 7ba4564602ecd..2cd02ae9269f5 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -38,6 +38,8 @@ import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket import org.apache.spark.Logging +import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.sql.hive.thriftserver.HiveThriftServerShim private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" @@ -116,7 +118,7 @@ private[hive] object SparkSQLCLIDriver { } } - if (!sessionState.isRemoteMode && !ShimLoader.getHadoopShims.usesJobShell()) { + if (!sessionState.isRemoteMode) { // Hadoop-20 and above - we need to augment classpath using hiveconf // components. // See also: code in ExecDriver.java @@ -258,7 +260,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } else { var ret = 0 val hconf = conf.asInstanceOf[HiveConf] - val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf) + val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf) if (proc != null) { if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor]) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 42cbf363b274f..499e077d7294a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -17,22 +17,24 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ - import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException +import scala.collection.JavaConversions._ + import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory -import org.apache.hive.service.cli.CLIService +import org.apache.hive.service.cli._ import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.util.Utils private[hive] class SparkSQLCLIService(hiveContext: HiveContext) extends CLIService @@ -44,19 +46,30 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) setSuperField(this, "sessionManager", sparkSqlSessionManager) addService(sparkSqlSessionManager) + var sparkServiceUGI: UserGroupInformation = null - try { - HiveAuthFactory.loginFromKeytab(hiveConf) - val serverUserName = ShimLoader.getHadoopShims - .getShortUserName(ShimLoader.getHadoopShims.getUGIForConf(hiveConf)) - setSuperField(this, "serverUserName", serverUserName) - } catch { - case e @ (_: IOException | _: LoginException) => - throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) + if (ShimLoader.getHadoopShims.isSecurityEnabled) { + try { + HiveAuthFactory.loginFromKeytab(hiveConf) + sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) + HiveThriftServerShim.setServerUserName(sparkServiceUGI, this) + } catch { + case e @ (_: IOException | _: LoginException) => + throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) + } } initCompositeService(hiveConf) } + + override def getInfo(sessionHandle: SessionHandle, getInfoType: GetInfoType): GetInfoValue = { + getInfoType match { + case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL") + case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL") + case GetInfoType.CLI_DBMS_VER => new GetInfoValue(hiveContext.sparkContext.version) + case _ => super.getInfo(sessionHandle, getInfoType) + } + } } private[thriftserver] trait ReflectedCompositeService { this: AbstractService => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 2136a2ea63543..89732c939b0ec 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.hive.thriftserver -import org.apache.hadoop.hive.ql.session.SessionState +import scala.collection.JavaConversions._ -import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} -import org.apache.spark.Logging -import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.scheduler.StatsReportListener +import org.apache.spark.sql.hive.{HiveShim, HiveContext} +import org.apache.spark.{Logging, SparkConf, SparkContext} /** A singleton object for the master program. The slaves should not access this. */ private[hive] object SparkSQLEnv extends Logging { @@ -33,18 +32,18 @@ private[hive] object SparkSQLEnv extends Logging { def init() { if (hiveContext == null) { - sparkContext = new SparkContext(new SparkConf() - .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}")) + val sparkConf = new SparkConf() + .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}") + .set("spark.sql.hive.version", HiveShim.version) + sparkContext = new SparkContext(sparkConf) sparkContext.addSparkListener(new StatsReportListener()) + hiveContext = new HiveContext(sparkContext) - hiveContext = new HiveContext(sparkContext) { - @transient override lazy val sessionState = { - val state = SessionState.get() - setConf(state.getConf.getAllProperties) - state + if (log.isDebugEnabled) { + hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => + logDebug(s"HiveConf var: $k=$v") } - @transient override lazy val hiveconf = sessionState.getConf } } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index accf61576b804..99c4f46a82b8e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -17,24 +17,15 @@ package org.apache.spark.sql.hive.thriftserver.server -import java.sql.Timestamp import java.util.{Map => JMap} +import scala.collection.mutable.Map -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, Map} -import scala.math.{random, round} - -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} -import org.apache.spark.sql.catalyst.plans.logical.SetCommand -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.{SparkExecuteStatementOperation, ReflectionUtils} /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -54,159 +45,9 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) confOverlay: JMap[String, String], async: Boolean): ExecuteStatementOperation = synchronized { - val operation = new ExecuteStatementOperation(parentSession, statement, confOverlay) { - private var result: SchemaRDD = _ - private var iter: Iterator[SparkRow] = _ - private var dataTypes: Array[DataType] = _ - - def close(): Unit = { - // RDDs will be cleaned automatically upon garbage collection. - logDebug("CLOSING") - } - - def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { - if (!iter.hasNext) { - new RowSet() - } else { - // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int - val maxRows = maxRowsL.toInt - var curRow = 0 - var rowSet = new ArrayBuffer[Row](maxRows.min(1024)) - - while (curRow < maxRows && iter.hasNext) { - val sparkRow = iter.next() - val row = new Row() - var curCol = 0 - - while (curCol < sparkRow.length) { - if (sparkRow.isNullAt(curCol)) { - addNullColumnValue(sparkRow, row, curCol) - } else { - addNonNullColumnValue(sparkRow, row, curCol) - } - curCol += 1 - } - rowSet += row - curRow += 1 - } - new RowSet(rowSet, 0) - } - } - - def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(from(ordinal).asInstanceOf[String]) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal))) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal))) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) - case DecimalType => - val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal - to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) - case LongType => - to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal))) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal))) - case TimestampType => - to.addColumnValue( - ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = result - .queryExecution - .asInstanceOf[HiveContext#QueryExecution] - .toHiveString((from.get(ordinal), dataTypes(ordinal))) - to.addColumnValue(ColumnValue.stringValue(hiveString)) - } - } - - def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(null) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(null)) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(null)) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(null)) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(null)) - case DecimalType => - to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) - case LongType => - to.addColumnValue(ColumnValue.longValue(null)) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(null)) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(null)) - case TimestampType => - to.addColumnValue(ColumnValue.timestampValue(null)) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - to.addColumnValue(ColumnValue.stringValue(null: String)) - } - } - - def getResultSetSchema: TableSchema = { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - if (result.queryExecution.analyzed.output.size == 0) { - new TableSchema(new FieldSchema("Result", "string", "") :: Nil) - } else { - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - new TableSchema(schema) - } - } - - def run(): Unit = { - logInfo(s"Running query '$statement'") - setState(OperationState.RUNNING) - try { - result = hiveContext.sql(statement) - logDebug(result.queryExecution.toString()) - result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) => - sessionToActivePool(parentSession) = value - logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") - case _ => - } - - val groupId = round(random * 1000000).toString - hiveContext.sparkContext.setJobGroup(groupId, statement) - sessionToActivePool.get(parentSession).foreach { pool => - hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) - } - iter = { - val resultRdd = result.queryExecution.toRdd - val useIncrementalCollect = - hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean - if (useIncrementalCollect) { - resultRdd.toLocalIterator - } else { - resultRdd.collect().iterator - } - } - dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray - setHasResultSet(true) - } catch { - // Actually do need to catch Throwable as some failures don't inherit from Exception and - // HiveServer will silently swallow them. - case e: Throwable => - logError("Error executing query:",e) - throw new HiveSQLException(e.toString) - } - setState(OperationState.FINISHED) - } - } - - handleToOperation.put(operation.getHandle, operation) - operation + val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)( + hiveContext, sessionToActivePool) + handleToOperation.put(operation.getHandle, operation) + operation } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 8a72e9d2aef57..e8ffbc5b954d4 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -18,19 +18,17 @@ package org.apache.spark.sql.hive.thriftserver +import java.io._ + import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ -import scala.concurrent.{Await, Future, Promise} +import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} -import java.io._ -import java.util.concurrent.atomic.AtomicInteger - import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.util.getTempFilePath class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { @@ -53,23 +51,20 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { """.stripMargin.split("\\s+").toSeq ++ extraArgs } - // AtomicInteger is needed because stderr and stdout of the forked process are handled in - // different threads. - val next = new AtomicInteger(0) + var next = 0 val foundAllExpectedAnswers = Promise.apply[Unit]() val queryStream = new ByteArrayInputStream(queries.mkString("\n").getBytes) val buffer = new ArrayBuffer[String]() + val lock = new Object - def captureOutput(source: String)(line: String) { + def captureOutput(source: String)(line: String): Unit = lock.synchronized { buffer += s"$source> $line" - // If we haven't found all expected answers... - if (next.get() < expectedAnswers.size) { - // If another expected answer is found... - if (line.startsWith(expectedAnswers(next.get()))) { - // If all expected answers have been found... - if (next.incrementAndGet() == expectedAnswers.size) { - foundAllExpectedAnswers.trySuccess(()) - } + // If we haven't found all expected answers and another expected answer comes up... + if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) { + next += 1 + // If all expected answers have been found... + if (next == expectedAnswers.size) { + foundAllExpectedAnswers.trySuccess(()) } } } @@ -88,8 +83,8 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { |======================= |Spark SQL CLI command line: ${command.mkString(" ")} | - |Executed query ${next.get()} "${queries(next.get())}", - |But failed to capture expected output "${expectedAnswers(next.get())}" within $timeout. + |Executed query $next "${queries(next)}", + |But failed to capture expected output "${expectedAnswers(next)}" within $timeout. | |${buffer.mkString("\n")} |=========================== diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index e3b4e45a3d68c..bba29b2bdca4d 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -30,42 +30,95 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver +import org.apache.hive.service.auth.PlainSaslHelper +import org.apache.hive.service.cli.GetInfoType +import org.apache.hive.service.cli.thrift.TCLIService.Client +import org.apache.hive.service.cli.thrift._ +import org.apache.thrift.protocol.TBinaryProtocol +import org.apache.thrift.transport.TSocket import org.scalatest.FunSuite import org.apache.spark.Logging import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.hive.HiveShim /** * Tests for the HiveThriftServer2 using JDBC. + * + * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be + * rebuilt after changing HiveThriftServer2 related code. */ class HiveThriftServer2Suite extends FunSuite with Logging { Class.forName(classOf[HiveDriver].getCanonicalName) - def startThriftServerWithin(timeout: FiniteDuration = 1.minute)(f: Statement => Unit) { + def randomListeningPort = { + // Let the system to choose a random available port to avoid collision with other parallel + // builds. + val socket = new ServerSocket(0) + val port = socket.getLocalPort + socket.close() + port + } + + def withJdbcStatement(serverStartTimeout: FiniteDuration = 1.minute)(f: Statement => Unit) { + val port = randomListeningPort + + startThriftServer(port, serverStartTimeout) { + val jdbcUri = s"jdbc:hive2://${"localhost"}:$port/" + val user = System.getProperty("user.name") + val connection = DriverManager.getConnection(jdbcUri, user, "") + val statement = connection.createStatement() + + try { + f(statement) + } finally { + statement.close() + connection.close() + } + } + } + + def withCLIServiceClient( + serverStartTimeout: FiniteDuration = 1.minute)( + f: ThriftCLIServiceClient => Unit) { + val port = randomListeningPort + + startThriftServer(port) { + // Transport creation logics below mimics HiveConnection.createBinaryTransport + val rawTransport = new TSocket("localhost", port) + val user = System.getProperty("user.name") + val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) + val protocol = new TBinaryProtocol(transport) + val client = new ThriftCLIServiceClient(new Client(protocol)) + + transport.open() + + try { + f(client) + } finally { + transport.close() + } + } + } + + def startThriftServer( + port: Int, + serverStartTimeout: FiniteDuration = 1.minute)( + f: => Unit) { val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) val warehousePath = getTempFilePath("warehouse") val metastorePath = getTempFilePath("metastore") val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - val listeningHost = "localhost" - val listeningPort = { - // Let the system to choose a random available port to avoid collision with other parallel - // builds. - val socket = new ServerSocket(0) - val port = socket.getLocalPort - socket.close() - port - } - val command = s"""$startScript | --master local | --hiveconf hive.root.logger=INFO,console | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$listeningHost - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=${"localhost"} + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port """.stripMargin.split("\\s+").toSeq val serverRunning = Promise[Unit]() @@ -92,31 +145,25 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } } - // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths - Process(command, None, "SPARK_TESTING" -> "0").run(ProcessLogger( + val env = Seq( + // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths + "SPARK_TESTING" -> "0", + // Prevents loading classes out of the assembly jar. Otherwise Utils.sparkVersion can't read + // proper version information from the jar manifest. + "SPARK_PREPEND_CLASSES" -> "") + + Process(command, None, env: _*).run(ProcessLogger( captureThriftServerOutput("stdout"), captureThriftServerOutput("stderr"))) - val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/" - val user = System.getProperty("user.name") - try { - Await.result(serverRunning.future, timeout) - - val connection = DriverManager.getConnection(jdbcUri, user, "") - val statement = connection.createStatement() - - try { - f(statement) - } finally { - statement.close() - connection.close() - } + Await.result(serverRunning.future, serverStartTimeout) + f } catch { case cause: Exception => cause match { case _: TimeoutException => - logError(s"Failed to start Hive Thrift server within $timeout", cause) + logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause) case _ => } logError( @@ -125,8 +172,8 @@ class HiveThriftServer2Suite extends FunSuite with Logging { |HiveThriftServer2Suite failure output |===================================== |HiveThriftServer2 command line: ${command.mkString(" ")} - |JDBC URI: $jdbcUri - |User: $user + |Binding port: $port + |System user: ${System.getProperty("user.name")} | |${buffer.mkString("\n")} |========================================= @@ -146,14 +193,16 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } test("Test JDBC query execution") { - startThriftServerWithin() { statement => + withJdbcStatement() { statement => val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") - val queries = Seq( - "CREATE TABLE test(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test", - "CACHE TABLE test") + val queries = + s"""SET spark.sql.shuffle.partitions=3; + |CREATE TABLE test(key INT, val STRING); + |LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test; + |CACHE TABLE test; + """.stripMargin.split(";").map(_.trim).filter(_.nonEmpty) queries.foreach(statement.execute) @@ -166,7 +215,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } test("SPARK-3004 regression: result set containing NULL") { - startThriftServerWithin() { statement => + withJdbcStatement() { statement => val dataFilePath = Thread.currentThread().getContextClassLoader.getResource( "data/files/small_kv_with_null.txt") @@ -189,4 +238,56 @@ class HiveThriftServer2Suite extends FunSuite with Logging { assert(!resultSet.next()) } } + + test("GetInfo Thrift API") { + withCLIServiceClient() { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + + assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue + } + + assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue + } + + assertResult(true, "Spark version shouldn't be \"Unknown\"") { + val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue + logInfo(s"Spark version: $version") + version != "Unknown" + } + } + } + + test("Checks Hive version") { + withJdbcStatement() { statement => + val resultSet = statement.executeQuery("SET spark.sql.hive.version") + resultSet.next() + assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + } + } + + test("SPARK-4292 regression: result set iterator issue") { + withJdbcStatement() { statement => + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + + val queries = Seq( + "DROP TABLE IF EXISTS test_4292", + "CREATE TABLE test_4292(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_4292") + + queries.foreach(statement.execute) + + val resultSet = statement.executeQuery("SELECT key FROM test_4292") + + Seq(238, 86, 311, 27, 165).foreach { key => + resultSet.next() + assert(resultSet.getInt(1) == key) + } + + statement.executeQuery("DROP TABLE IF EXISTS test_4292") + } + } } diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala new file mode 100644 index 0000000000000..aa2e3cab72bb9 --- /dev/null +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.sql.Timestamp +import java.util.{ArrayList => JArrayList, Map => JMap} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, Map => SMap} +import scala.math._ + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.ExecuteStatementOperation +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.plans.logical.SetCommand +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow} + +/** + * A compatibility layer for interacting with Hive version 0.12.0. + */ +private[thriftserver] object HiveThriftServerShim { + val version = "0.12.0" + + def setServerUserName(sparkServiceUGI: UserGroupInformation, sparkCliService:SparkSQLCLIService) = { + val serverUserName = ShimLoader.getHadoopShims.getShortUserName(sparkServiceUGI) + setSuperField(sparkCliService, "serverUserName", serverUserName) + } +} + +private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) + extends AbstractSparkSQLDriver(_context) { + override def getResults(res: JArrayList[String]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.addAll(hiveResponse) + hiveResponse = null + true + } + } +} + +private[hive] class SparkExecuteStatementOperation( + parentSession: HiveSession, + statement: String, + confOverlay: JMap[String, String])( + hiveContext: HiveContext, + sessionToActivePool: SMap[HiveSession, String]) + extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging { + + private var result: SchemaRDD = _ + private var iter: Iterator[SparkRow] = _ + private var dataTypes: Array[DataType] = _ + + def close(): Unit = { + // RDDs will be cleaned automatically upon garbage collection. + logDebug("CLOSING") + } + + def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { + if (!iter.hasNext) { + new RowSet() + } else { + // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int + val maxRows = maxRowsL.toInt + var curRow = 0 + var rowSet = new ArrayBuffer[Row](maxRows.min(1024)) + + while (curRow < maxRows && iter.hasNext) { + val sparkRow = iter.next() + val row = new Row() + var curCol = 0 + + while (curCol < sparkRow.length) { + if (sparkRow.isNullAt(curCol)) { + addNullColumnValue(sparkRow, row, curCol) + } else { + addNonNullColumnValue(sparkRow, row, curCol) + } + curCol += 1 + } + rowSet += row + curRow += 1 + } + new RowSet(rowSet, 0) + } + } + + def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to.addString(from(ordinal).asInstanceOf[String]) + case IntegerType => + to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal))) + case BooleanType => + to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal))) + case DoubleType => + to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) + case FloatType => + to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) + case DecimalType() => + val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal + to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) + case LongType => + to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal))) + case ByteType => + to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) + case ShortType => + to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal))) + case TimestampType => + to.addColumnValue( + ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = result + .queryExecution + .asInstanceOf[HiveContext#QueryExecution] + .toHiveString((from.get(ordinal), dataTypes(ordinal))) + to.addColumnValue(ColumnValue.stringValue(hiveString)) + } + } + + def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to.addString(null) + case IntegerType => + to.addColumnValue(ColumnValue.intValue(null)) + case BooleanType => + to.addColumnValue(ColumnValue.booleanValue(null)) + case DoubleType => + to.addColumnValue(ColumnValue.doubleValue(null)) + case FloatType => + to.addColumnValue(ColumnValue.floatValue(null)) + case DecimalType() => + to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) + case LongType => + to.addColumnValue(ColumnValue.longValue(null)) + case ByteType => + to.addColumnValue(ColumnValue.byteValue(null)) + case ShortType => + to.addColumnValue(ColumnValue.shortValue(null)) + case TimestampType => + to.addColumnValue(ColumnValue.timestampValue(null)) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + to.addColumnValue(ColumnValue.stringValue(null: String)) + } + } + + def getResultSetSchema: TableSchema = { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") + if (result.queryExecution.analyzed.output.size == 0) { + new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + } else { + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema) + } + } + + def run(): Unit = { + logInfo(s"Running query '$statement'") + setState(OperationState.RUNNING) + try { + result = hiveContext.sql(statement) + logDebug(result.queryExecution.toString()) + result.queryExecution.logical match { + case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) => + sessionToActivePool(parentSession) = value + logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") + case _ => + } + + val groupId = round(random * 1000000).toString + hiveContext.sparkContext.setJobGroup(groupId, statement) + sessionToActivePool.get(parentSession).foreach { pool => + hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) + } + iter = { + val useIncrementalCollect = + hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean + if (useIncrementalCollect) { + result.toLocalIterator + } else { + result.collect().iterator + } + } + dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + setHasResultSet(true) + } catch { + // Actually do need to catch Throwable as some failures don't inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + setState(OperationState.ERROR) + logError("Error executing query:",e) + throw new HiveSQLException(e.toString) + } + setState(OperationState.FINISHED) + } +} diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala new file mode 100644 index 0000000000000..a642478d08857 --- /dev/null +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.security.PrivilegedExceptionAction +import java.sql.Timestamp +import java.util.concurrent.Future +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, Map => SMap} +import scala.math._ + +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.ExecuteStatementOperation +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.{SchemaRDD, Row => SparkRow} + +/** + * A compatibility layer for interacting with Hive version 0.12.0. + */ +private[thriftserver] object HiveThriftServerShim { + val version = "0.13.1" + + def setServerUserName(sparkServiceUGI: UserGroupInformation, sparkCliService:SparkSQLCLIService) = { + setSuperField(sparkCliService, "serviceUGI", sparkServiceUGI) + } +} + +private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) + extends AbstractSparkSQLDriver(_context) { + override def getResults(res: JList[_]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) + hiveResponse = null + true + } + } +} + +private[hive] class SparkExecuteStatementOperation( + parentSession: HiveSession, + statement: String, + confOverlay: JMap[String, String], + runInBackground: Boolean = true)( + hiveContext: HiveContext, + sessionToActivePool: SMap[HiveSession, String]) extends ExecuteStatementOperation( + parentSession, statement, confOverlay, runInBackground) with Logging { + + private var result: SchemaRDD = _ + private var iter: Iterator[SparkRow] = _ + private var dataTypes: Array[DataType] = _ + + private def runInternal(cmd: String) = { + try { + result = hiveContext.sql(cmd) + logDebug(result.queryExecution.toString()) + val groupId = round(random * 1000000).toString + hiveContext.sparkContext.setJobGroup(groupId, statement) + iter = { + val useIncrementalCollect = + hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean + if (useIncrementalCollect) { + result.toLocalIterator + } else { + result.collect().iterator + } + } + dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + } catch { + // Actually do need to catch Throwable as some failures don't inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + setState(OperationState.ERROR) + logError("Error executing query:",e) + throw new HiveSQLException(e.toString) + } + } + + def close(): Unit = { + // RDDs will be cleaned automatically upon garbage collection. + logDebug("CLOSING") + } + + def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to += from.get(ordinal).asInstanceOf[String] + case IntegerType => + to += from.getInt(ordinal) + case BooleanType => + to += from.getBoolean(ordinal) + case DoubleType => + to += from.getDouble(ordinal) + case FloatType => + to += from.getFloat(ordinal) + case DecimalType() => + to += from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal + case LongType => + to += from.getLong(ordinal) + case ByteType => + to += from.getByte(ordinal) + case ShortType => + to += from.getShort(ordinal) + case TimestampType => + to += from.get(ordinal).asInstanceOf[Timestamp] + case BinaryType => + to += from.get(ordinal).asInstanceOf[String] + case _: ArrayType => + to += from.get(ordinal).asInstanceOf[String] + case _: StructType => + to += from.get(ordinal).asInstanceOf[String] + case _: MapType => + to += from.get(ordinal).asInstanceOf[String] + } + } + + def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { + validateDefaultFetchOrientation(order) + assertState(OperationState.FINISHED) + setHasResultSet(true) + val reultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion) + if (!iter.hasNext) { + reultRowSet + } else { + // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int + val maxRows = maxRowsL.toInt + var curRow = 0 + while (curRow < maxRows && iter.hasNext) { + val sparkRow = iter.next() + val row = ArrayBuffer[Any]() + var curCol = 0 + while (curCol < sparkRow.length) { + if (sparkRow.isNullAt(curCol)) { + row += null + } else { + addNonNullColumnValue(sparkRow, row, curCol) + } + curCol += 1 + } + reultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]]) + curRow += 1 + } + reultRowSet + } + } + + def getResultSetSchema: TableSchema = { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") + if (result.queryExecution.analyzed.output.size == 0) { + new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + } else { + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema) + } + } + + private def getConfigForOperation: HiveConf = { + var sqlOperationConf: HiveConf = getParentSession.getHiveConf + if (!getConfOverlay.isEmpty || shouldRunAsync) { + sqlOperationConf = new HiveConf(sqlOperationConf) + import scala.collection.JavaConversions._ + for (confEntry <- getConfOverlay.entrySet) { + try { + sqlOperationConf.verifyAndSet(confEntry.getKey, confEntry.getValue) + } + catch { case e: IllegalArgumentException => + throw new HiveSQLException("Error applying statement specific settings", e) + } + } + } + sqlOperationConf + } + + def run(): Unit = { + logInfo(s"Running query '$statement'") + val opConfig: HiveConf = getConfigForOperation + setState(OperationState.RUNNING) + setHasResultSet(true) + + if (!shouldRunAsync) { + runInternal(statement) + setState(OperationState.FINISHED) + } else { + val parentSessionState = SessionState.get + val sessionHive: Hive = Hive.get + val currentUGI: UserGroupInformation = ShimLoader.getHadoopShims.getUGIForConf(opConfig) + + val backgroundOperation: Runnable = new Runnable { + def run() { + val doAsAction: PrivilegedExceptionAction[AnyRef] = + new PrivilegedExceptionAction[AnyRef] { + def run: AnyRef = { + Hive.set(sessionHive) + SessionState.setCurrentSessionState(parentSessionState) + try { + runInternal(statement) + } + catch { case e: HiveSQLException => + setOperationException(e) + logError("Error running hive query: ", e) + } + null + } + } + try { + ShimLoader.getHadoopShims.doAs(currentUGI, doAsAction) + } + catch { case e: Exception => + setOperationException(new HiveSQLException(e)) + logError("Error running hive query as user : " + currentUGI.getShortUserName, e) + } + setState(OperationState.FINISHED) + } + } + + try { + val backgroundHandle: Future[_] = getParentSession.getSessionManager. + submitBackgroundOperation(backgroundOperation) + setBackgroundHandle(backgroundHandle) + } catch { + // Actually do need to catch Throwable as some failures don't inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + logError("Error executing query:",e) + throw new HiveSQLException(e.toString) + } + } + } +} diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index db01363b4d629..67e36a951e506 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -65,6 +65,10 @@ commons-logging commons-logging + + com.esotericsoftware.kryo + kryo + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fad4091d48a89..e88afaaf001c0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -21,6 +21,10 @@ import java.io.{BufferedReader, File, InputStreamReader, PrintStream} import java.sql.{Date, Timestamp} import java.util.{ArrayList => JArrayList} +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.spark.sql.catalyst.types.DecimalType +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.JavaConversions._ import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} @@ -46,6 +50,7 @@ import org.apache.spark.sql.execution.ExtractPythonUdfs import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand +import org.apache.spark.sql.sources.DataSourceStrategy /** * DEPRECATED: Use HiveContext instead. @@ -85,7 +90,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * SerDe. */ private[spark] def convertMetastoreParquet: Boolean = - getConf("spark.sql.hive.convertMetastoreParquet", "false") == "true" + getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } @@ -95,7 +100,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { if (dialect == "sql") { super.sql(sqlText) } else if (dialect == "hiveql") { - new SchemaRDD(this, HiveQl.parseSql(sqlText)) + new SchemaRDD(this, ddlParser(sqlText).getOrElse(HiveQl.parseSql(sqlText))) } else { sys.error(s"Unsupported SQL dialect: $dialect. Try 'sql' or 'hiveql'") } @@ -224,21 +229,29 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } /** - * SQLConf and HiveConf contracts: when the hive session is first initialized, params in - * HiveConf will get picked up by the SQLConf. Additionally, any properties set by - * set() or a SET command inside sql() will be set in the SQLConf *as well as* - * in the HiveConf. + * SQLConf and HiveConf contracts: + * + * 1. reuse existing started SessionState if any + * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the + * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be + * set in the SQLConf *as well as* in the HiveConf. */ - @transient lazy val hiveconf = new HiveConf(classOf[SessionState]) - @transient protected[hive] lazy val sessionState = { - val ss = new SessionState(hiveconf) - setConf(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. - SessionState.start(ss) - ss.err = new PrintStream(outputBuffer, true, "UTF-8") - ss.out = new PrintStream(outputBuffer, true, "UTF-8") - - ss - } + @transient protected[hive] lazy val (hiveconf, sessionState) = + Option(SessionState.get()) + .orElse { + val newState = new SessionState(new HiveConf(classOf[SessionState])) + // Only starts newly created `SessionState` instance. Any existing `SessionState` instance + // returned by `SessionState.get()` must be the most recently started one. + SessionState.start(newState) + Some(newState) + } + .map { state => + setConf(state.getConf.getAllProperties) + if (state.out == null) state.out = new PrintStream(outputBuffer, true, "UTF-8") + if (state.err == null) state.err = new PrintStream(outputBuffer, true, "UTF-8") + (state.getConf, state) + } + .get override def setConf(key: String, value: String): Unit = { super.setConf(key, value) @@ -288,6 +301,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hiveconf) + // Makes sure the session represented by the `sessionState` field is activated. This implies + // Spark SQL Hive support uses a single `SessionState` for all Hive operations and breaks + // session isolation under multi-user scenarios (i.e. HiveThriftServer2). + // TODO Fix session isolation + if (SessionState.get() != sessionState) { + SessionState.start(sessionState) + } + proc match { case driver: Driver => val results = HiveShim.createDriverResultsArray @@ -302,7 +323,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { driver.close() HiveShim.processResults(results) case _ => - sessionState.out.println(tokens(0) + " " + cmd_1) + if (sessionState.out != null) { + sessionState.out.println(tokens(0) + " " + cmd_1) + } Seq(proc.run(cmd_1).getResponseCode.toString) } } catch { @@ -325,7 +348,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val hivePlanner = new SparkPlanner with HiveStrategies { val hiveContext = self - override val strategies: Seq[Strategy] = Seq( + override val strategies: Seq[Strategy] = extraStrategies ++ Seq( + DataSourceStrategy, CommandStrategy(self), HiveCommandStrategy(self), TakeOrdered, @@ -350,11 +374,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** Extends QueryExecution with hive specific features. */ protected[sql] abstract class QueryExecution extends super.QueryExecution { - override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) - protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, - ShortType, DecimalType, DateType, TimestampType, BinaryType) + ShortType, DateType, TimestampType, BinaryType) protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => @@ -372,6 +394,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { case (d: Date, DateType) => new DateWritable(d).toString case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") + case (decimal: Decimal, DecimalType()) => // Hive strips trailing zeros so use its toString + HiveShim.createDecimal(decimal.toBigDecimal.underlying()).toString case (other, tpe) if primitiveTypes contains tpe => other.toString } @@ -390,6 +414,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { }.toSeq.sorted.mkString("{", ",", "}") case (null, _) => "null" case (s: String, StringType) => "\"" + s + "\"" + case (decimal, DecimalType()) => decimal.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } @@ -406,7 +431,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { command.executeCollect().map(_.head.toString) case other => - val result: Seq[Seq[Any]] = toRdd.collect().toSeq + val result: Seq[Seq[Any]] = toRdd.map(_.copy()).collect().toSeq // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) // Reformat to match hive tab delimited output. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index c6103a124df59..7e76aff642bb5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector._ @@ -28,6 +28,7 @@ import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -38,7 +39,7 @@ private[hive] trait HiveInspectors { // writable case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType @@ -54,8 +55,8 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.sql.Date] => DateType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType - case c: Class[_] if c == classOf[HiveDecimal] => DecimalType - case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType + case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited + case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited case c: Class[_] if c == classOf[Array[Byte]] => BinaryType case c: Class[_] if c == classOf[java.lang.Short] => ShortType case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType @@ -87,10 +88,14 @@ private[hive] trait HiveInspectors { * @return convert the data into catalyst type */ def unwrap(data: Any, oi: ObjectInspector): Any = oi match { + case _ if data == null => null case hvoi: HiveVarcharObjectInspector => if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue case hdoi: HiveDecimalObjectInspector => - if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + if (data == null) null else HiveShim.toCatalystDecimal(hdoi, data) + // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object + // if next timestamp is null, so Timestamp object is cloned + case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone() case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) case li: ListObjectInspector => Option(li.getList(data)) @@ -110,6 +115,47 @@ private[hive] trait HiveInspectors { unwrap(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) } + + /** + * Wraps with Hive types based on object inspector. + * TODO: Consolidate all hive OI/data interface code. + */ + protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { + case _: JavaHiveVarcharObjectInspector => + (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + + case _: JavaHiveDecimalObjectInspector => + (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toBigDecimal.underlying()) + + case soi: StandardStructObjectInspector => + val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) + (o: Any) => { + val struct = soi.create() + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { + (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + } + struct + } + + case loi: ListObjectInspector => + val wrapper = wrapperFor(loi.getListElementObjectInspector) + (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) + + case moi: MapObjectInspector => + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) + val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) + (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => + keyWrapper(key) -> valueWrapper(value) + }) + + case _ => + identity[Any] + } + /** * Converts native catalyst types to the types expected by Hive * @param a the value to be wrapped @@ -134,8 +180,9 @@ private[hive] trait HiveInspectors { case l: Short => l: java.lang.Short case l: Byte => l: java.lang.Byte case b: BigDecimal => HiveShim.createDecimal(b.underlying()) + case d: Decimal => HiveShim.createDecimal(d.toBigDecimal.underlying()) case b: Array[Byte] => b - case d: java.sql.Date => d + case d: java.sql.Date => d case t: java.sql.Timestamp => t } case x: StructObjectInspector => @@ -197,51 +244,60 @@ private[hive] trait HiveInspectors { case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector - case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector + case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) } def toInspector(expr: Expression): ObjectInspector = expr match { - case Literal(value: String, StringType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Int, IntegerType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Double, DoubleType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Boolean, BooleanType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Long, LongType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Float, FloatType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Short, ShortType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Byte, ByteType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: Array[Byte], BinaryType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: java.sql.Date, DateType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: java.sql.Timestamp, TimestampType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: BigDecimal, DecimalType) => - HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value, StringType) => + HiveShim.getStringWritableConstantObjectInspector(value) + case Literal(value, IntegerType) => + HiveShim.getIntWritableConstantObjectInspector(value) + case Literal(value, DoubleType) => + HiveShim.getDoubleWritableConstantObjectInspector(value) + case Literal(value, BooleanType) => + HiveShim.getBooleanWritableConstantObjectInspector(value) + case Literal(value, LongType) => + HiveShim.getLongWritableConstantObjectInspector(value) + case Literal(value, FloatType) => + HiveShim.getFloatWritableConstantObjectInspector(value) + case Literal(value, ShortType) => + HiveShim.getShortWritableConstantObjectInspector(value) + case Literal(value, ByteType) => + HiveShim.getByteWritableConstantObjectInspector(value) + case Literal(value, BinaryType) => + HiveShim.getBinaryWritableConstantObjectInspector(value) + case Literal(value, DateType) => + HiveShim.getDateWritableConstantObjectInspector(value) + case Literal(value, TimestampType) => + HiveShim.getTimestampWritableConstantObjectInspector(value) + case Literal(value, DecimalType()) => + HiveShim.getDecimalWritableConstantObjectInspector(value) case Literal(_, NullType) => HiveShim.getPrimitiveNullWritableConstantObjectInspector - case Literal(value: Seq[_], ArrayType(dt, _)) => + case Literal(value, ArrayType(dt, _)) => val listObjectInspector = toInspector(dt) - val list = new java.util.ArrayList[Object]() - value.foreach(v => list.add(wrap(v, listObjectInspector))) - ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) - case Literal(map: Map[_, _], MapType(keyType, valueType, _)) => - val value = new java.util.HashMap[Object, Object]() + if (value == null) { + ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) + } else { + val list = new java.util.ArrayList[Object]() + value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector))) + ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) + } + case Literal(value, MapType(keyType, valueType, _)) => val keyOI = toInspector(keyType) val valueOI = toInspector(valueType) - map.foreach (entry => value.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI))) - ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, value) - case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") + if (value == null) { + ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, null) + } else { + val map = new java.util.HashMap[Object, Object]() + value.asInstanceOf[Map[_, _]].foreach (entry => { + map.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI)) + }) + ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map) + } case _ => toInspector(expr.dataType) } @@ -274,8 +330,8 @@ private[hive] trait HiveInspectors { case _: JavaFloatObjectInspector => FloatType case _: WritableBinaryObjectInspector => BinaryType case _: JavaBinaryObjectInspector => BinaryType - case _: WritableHiveDecimalObjectInspector => DecimalType - case _: JavaHiveDecimalObjectInspector => DecimalType + case w: WritableHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(w) + case j: JavaHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(j) case _: WritableDateObjectInspector => DateType case _: JavaDateObjectInspector => DateType case _: WritableTimestampObjectInspector => TimestampType @@ -304,7 +360,7 @@ private[hive] trait HiveInspectors { case LongType => longTypeInfo case ShortType => shortTypeInfo case StringType => stringTypeInfo - case DecimalType => decimalTypeInfo + case d: DecimalType => HiveShim.decimalTypeInfo(d) case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 2dd2c882a8420..9ae019842217d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import java.io.IOException import java.util.{List => JList} +import scala.util.matching.Regex import scala.util.parsing.combinator.RegexParsers import org.apache.hadoop.util.ReflectionUtils @@ -56,6 +57,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val caseSensitive: Boolean = false + def tableExists(db: Option[String], tableName: String): Boolean = { + val (databaseName, tblName) = processDatabaseAndTableName( + db.getOrElse(hive.sessionState.getCurrentDatabase), tableName) + client.getTable(databaseName, tblName, false) != null + } + def lookupRelation( db: Option[String], tableName: String, @@ -321,11 +328,18 @@ object HiveMetastoreTypes extends RegexParsers { "bigint" ^^^ LongType | "binary" ^^^ BinaryType | "boolean" ^^^ BooleanType | - HiveShim.metastoreDecimal ^^^ DecimalType | + fixedDecimalType | // Hive 0.13+ decimal with precision/scale + "decimal" ^^^ DecimalType.Unlimited | // Hive 0.12 decimal with no precision/scale "date" ^^^ DateType | "timestamp" ^^^ TimestampType | "varchar\\((\\d+)\\)".r ^^^ StringType + protected lazy val fixedDecimalType: Parser[DataType] = + ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ { + case precision ~ scale => + DecimalType(precision.toInt, scale.toInt) + } + protected lazy val arrayType: Parser[DataType] = "array" ~> "<" ~> dataType <~ ">" ^^ { case tpe => ArrayType(tpe) @@ -373,9 +387,10 @@ object HiveMetastoreTypes extends RegexParsers { case BinaryType => "binary" case BooleanType => "boolean" case DateType => "date" - case DecimalType => "decimal" + case d: DecimalType => HiveShim.decimalMetastoreString(d) case TimestampType => "timestamp" case NullType => "void" + case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) } } @@ -441,7 +456,7 @@ private[hive] case class MetastoreRelation val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = hiveQlTable.getCols.map(_.toAttribute) + val attributes = hiveQlTable.getCols.map(_.toAttribute) val output = attributes ++ partitionKeys diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 9d9d68affa54b..74f68d0f95317 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -325,7 +326,11 @@ private[hive] object HiveQl { } protected def nodeToDataType(node: Node): DataType = node match { - case Token("TOK_DECIMAL", Nil) => DecimalType + case Token("TOK_DECIMAL", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case Token("TOK_DECIMAL", precision :: Nil) => + DecimalType(precision.getText.toInt, 0) + case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited case Token("TOK_BIGINT", Nil) => LongType case Token("TOK_INT", Nil) => IntegerType case Token("TOK_TINYINT", Nil) => ByteType @@ -942,8 +947,12 @@ private[hive] object HiveQl { Cast(nodeToExpr(arg), BinaryType) case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), BooleanType) + case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, scale.getText.toInt)) + case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0)) case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType) + Cast(nodeToExpr(arg), DecimalType.Unlimited) case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), TimestampType) case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => @@ -985,15 +994,20 @@ private[hive] object HiveQl { In(nodeToExpr(value), list.map(nodeToExpr)) case Token("TOK_FUNCTION", Token(BETWEEN(), Nil) :: - Token("KW_FALSE", Nil) :: + kw :: target :: minValue :: maxValue :: Nil) => val targetExpression = nodeToExpr(target) - And( - GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), - LessThanOrEqual(targetExpression, nodeToExpr(maxValue))) + val betweenExpr = + And( + GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), + LessThanOrEqual(targetExpression, nodeToExpr(maxValue))) + kw match { + case Token("KW_FALSE", Nil) => betweenExpr + case Token("KW_TRUE", Nil) => Not(betweenExpr) + } /* Boolean Logic */ case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right)) @@ -1058,7 +1072,7 @@ private[hive] object HiveQl { } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) { // Literal decimal val strVal = ast.getText.stripSuffix("D").stripSuffix("B") - v = Literal(BigDecimal(strVal)) + v = Literal(Decimal(strVal)) } else { v = Literal(ast.getText.toDouble, DoubleType) v = Literal(ast.getText.toLong, LongType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index e59d4d536a0af..989740c8d43b6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan} import org.apache.spark.sql.hive import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy} import scala.collection.JavaConversions._ @@ -206,6 +206,8 @@ private[hive] trait HiveStrategies { case hive.AddJar(path) => execution.AddJar(path) :: Nil + case hive.AddFile(path) => execution.AddFile(path) :: Nil + case hive.AnalyzeTable(tableName) => execution.AnalyzeTable(tableName) :: Nil case describe: logical.DescribeCommand => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index e49f0957d188a..f60bc3788e3e4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -290,6 +290,21 @@ private[hive] object HadoopTableReader extends HiveInspectors { (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) case oi: DoubleObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi: HiveVarcharObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.setString(ordinal, oi.getPrimitiveJavaObject(value).getValue) + case oi: HiveDecimalObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) + case oi: TimestampObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, oi.getPrimitiveJavaObject(value).clone()) + case oi: DateObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, oi.getPrimitiveJavaObject(value)) + case oi: BinaryObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, oi.getPrimitiveJavaObject(value)) case oi => (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 2fce414734579..3d24d87bc3d38 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -71,7 +71,17 @@ case class CreateTableAsSelect( // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. - sc.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true)).toRdd + if (sc.catalog.tableExists(Some(database), tableName)) { + if (allowExisting) { + // table already exists, will do nothing, to keep consistent with Hive + } else { + throw + new org.apache.hadoop.hive.metastore.api.AlreadyExistsException(s"$database.$tableName") + } + } else { + sc.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true)).toRdd + } + Seq.empty[Row] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 79234f8a66f05..81390f626726c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.hive.execution +import java.util + import scala.collection.JavaConversions._ -import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.common.`type`.HiveVarchar import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.MetaStoreUtils @@ -35,6 +37,7 @@ import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.execution.{Command, SparkPlan, UnaryNode} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc} @@ -51,7 +54,7 @@ case class InsertIntoHiveTable( child: SparkPlan, overwrite: Boolean) (@transient sc: HiveContext) - extends UnaryNode with Command { + extends UnaryNode with Command with HiveInspectors { @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) @@ -67,46 +70,6 @@ case class InsertIntoHiveTable( def output = child.output - /** - * Wraps with Hive types based on object inspector. - * TODO: Consolidate all hive OI/data interface code. - */ - protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { - case _: JavaHiveVarcharObjectInspector => - (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) - - case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveShim.createDecimal(o.asInstanceOf[BigDecimal].underlying()) - - case soi: StandardStructObjectInspector => - val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) - (o: Any) => { - val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { - (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) - } - struct - } - - case loi: ListObjectInspector => - val wrapper = wrapperFor(loi.getListElementObjectInspector) - (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) - - case moi: MapObjectInspector => - // The Predef.Map is scala.collection.immutable.Map. - // Since the map values can be mutable, we explicitly import scala.collection.Map at here. - import scala.collection.Map - - val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) - val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) - (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => - keyWrapper(key) -> valueWrapper(value) - }) - - case _ => - identity[Any] - } - def saveAsHiveFile( rdd: RDD[Row], valueClass: Class[_], @@ -242,6 +205,13 @@ case class InsertIntoHiveTable( // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. val holdDDLTime = false if (partition.nonEmpty) { + + // loadPartition call orders directories created on the iteration order of the this map + val orderedPartitionSpec = new util.LinkedHashMap[String,String]() + table.hiveQlTable.getPartCols().foreach{ + entry=> + orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) + } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) db.validatePartitionNameCharacters(partVals) // inheritTableSpecs is set to true. It should be set to false for a IMPORT query @@ -253,7 +223,7 @@ case class InsertIntoHiveTable( db.loadDynamicPartitions( outputPath, qualifiedTableName, - partitionSpec, + orderedPartitionSpec, overwrite, numDynamicPartitions, holdDDLTime, @@ -263,7 +233,7 @@ case class InsertIntoHiveTable( db.loadPartition( outputPath, qualifiedTableName, - partitionSpec, + orderedPartitionSpec, overwrite, holdDDLTime, inheritTableSpecs, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 0fc674af31885..903075edf7e04 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -76,3 +76,19 @@ case class AddJar(path: String) extends LeafNode with Command { Seq.empty[Row] } } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class AddFile(path: String) extends LeafNode with Command { + def hiveContext = sqlContext.asInstanceOf[HiveContext] + + override def output = Seq.empty + + override protected lazy val sideEffectResult: Seq[Row] = { + hiveContext.runSqlHive(s"ADD FILE $path") + hiveContext.sparkContext.addFile(path) + Seq.empty[Row] + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index aff4ddce92272..86f7eea5dfd69 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.hive.ql.exec.{UDF, UDAF} import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis @@ -134,11 +135,19 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ } } +// Adapter from Catalyst ExpressionResult to Hive DeferredObject +private[hive] class DeferredObjectAdapter(oi: ObjectInspector) + extends DeferredObject with HiveInspectors { + private var func: () => Any = _ + def set(func: () => Any) { + this.func = func + } + override def prepare(i: Int) = {} + override def get(): AnyRef = wrap(func(), oi) +} + private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression]) extends HiveUdf with HiveInspectors { - - import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ - type UDFType = GenericUDF @transient @@ -161,16 +170,6 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq protected lazy val deferedObjects = argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject] - // Adapter from Catalyst ExpressionResult to Hive DeferredObject - class DeferredObjectAdapter(oi: ObjectInspector) extends DeferredObject { - private var func: () => Any = _ - def set(func: () => Any) { - this.func = func - } - override def prepare(i: Int) = {} - override def get(): AnyRef = wrap(func(), oi) - } - lazy val dataType: DataType = inspectorToDataType(returnInspector) override def eval(input: Row): Any = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index bf2ce9df67c58..cc8bb3e172c6e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} import org.apache.spark.sql.hive.{ShimFileSinkDesc => FileSinkDesc} diff --git a/sql/hive/src/test/resources/data/files/issue-4077-data.txt b/sql/hive/src/test/resources/data/files/issue-4077-data.txt new file mode 100644 index 0000000000000..18067b0a64c9c --- /dev/null +++ b/sql/hive/src/test/resources/data/files/issue-4077-data.txt @@ -0,0 +1,2 @@ +2014-12-11 00:00:00,1 +2014-12-11astring00:00:00,2 diff --git a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 b/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 new file mode 100644 index 0000000000000..7c41615f8c184 --- /dev/null +++ b/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 @@ -0,0 +1 @@ +1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL 1969-12-31 16:00:00.001 NULL 1 NULL diff --git a/sql/hive/src/test/resources/golden/truncate_table-1-7fc255c86d7c3a9ff088f9eb29a42565 b/sql/hive/src/test/resources/golden/truncate_table-1-7fc255c86d7c3a9ff088f9eb29a42565 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-10-c32b771845f4d5a0330e2cfa09f89a7f b/sql/hive/src/test/resources/golden/truncate_table-10-c32b771845f4d5a0330e2cfa09f89a7f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-7-1ad5d350714e3d4ea17201153772d58d b/sql/hive/src/test/resources/golden/truncate_table-7-1ad5d350714e3d4ea17201153772d58d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-8-76c754eac44c7254b45807255d4dbc3a b/sql/hive/src/test/resources/golden/truncate_table-8-76c754eac44c7254b45807255d4dbc3a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-9-f4286b5657674a6a6b6bc6680f72f89a b/sql/hive/src/test/resources/golden/truncate_table-9-f4286b5657674a6a6b6bc6680f72f89a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c new file mode 100644 index 0000000000000..2cf0d9d61882e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c @@ -0,0 +1 @@ +There is no documentation for function 'if' diff --git a/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a new file mode 100644 index 0000000000000..2cf0d9d61882e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a @@ -0,0 +1 @@ +There is no documentation for function 'if' diff --git a/sql/hive/src/test/resources/golden/udf_if-2-f2b010128e922d0096a65ddd9ae1d0b4 b/sql/hive/src/test/resources/golden/udf_if-2-f2b010128e922d0096a65ddd9ae1d0b4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_if-3-20206f17367ff284d67044abd745ce9f b/sql/hive/src/test/resources/golden/udf_if-3-20206f17367ff284d67044abd745ce9f new file mode 100644 index 0000000000000..a29e96cbd1db7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_if-3-20206f17367ff284d67044abd745ce9f @@ -0,0 +1 @@ +1 1 1 1 NULL 2 diff --git a/sql/hive/src/test/resources/golden/udf_if-4-174dae8a1eb4cad6ccf6f67203de71ca b/sql/hive/src/test/resources/golden/udf_if-4-174dae8a1eb4cad6ccf6f67203de71ca new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_if-5-a7db13aec05c97792f9331d63709d8cc b/sql/hive/src/test/resources/golden/udf_if-5-a7db13aec05c97792f9331d63709d8cc new file mode 100644 index 0000000000000..f0669b86989d0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_if-5-a7db13aec05c97792f9331d63709d8cc @@ -0,0 +1 @@ +128 1.1 ABC 12.3 diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_named_struct-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-1-8f0ea83364b78634fbb3752c5a5c725 b/sql/hive/src/test/resources/golden/udf_named_struct-1-8f0ea83364b78634fbb3752c5a5c725 new file mode 100644 index 0000000000000..9bff96e7fa20e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-1-8f0ea83364b78634fbb3752c5a5c725 @@ -0,0 +1 @@ +named_struct(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-2-380c9638cc6ea8ea42f187bf0cedf350 b/sql/hive/src/test/resources/golden/udf_named_struct-2-380c9638cc6ea8ea42f187bf0cedf350 new file mode 100644 index 0000000000000..9bff96e7fa20e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-2-380c9638cc6ea8ea42f187bf0cedf350 @@ -0,0 +1 @@ +named_struct(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-3-c069e28293a12a813f8e881f776bae90 b/sql/hive/src/test/resources/golden/udf_named_struct-3-c069e28293a12a813f8e881f776bae90 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-4-b499d4120e009f222f2fab160a9006d7 b/sql/hive/src/test/resources/golden/udf_named_struct-4-b499d4120e009f222f2fab160a9006d7 new file mode 100644 index 0000000000000..de25f51b5b56d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-4-b499d4120e009f222f2fab160a9006d7 @@ -0,0 +1 @@ +{"foo":1,"bar":2} 1 diff --git a/sql/hive/src/test/resources/golden/udf_struct-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_struct-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/udf_struct-1-f41043b7d9f14fa5e998c90454c7bdb1 b/sql/hive/src/test/resources/golden/udf_struct-1-f41043b7d9f14fa5e998c90454c7bdb1 new file mode 100644 index 0000000000000..062cb1bc683b1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-1-f41043b7d9f14fa5e998c90454c7bdb1 @@ -0,0 +1 @@ +struct(col1, col2, col3, ...) - Creates a struct with the given field values diff --git a/sql/hive/src/test/resources/golden/udf_struct-2-8ccdb20153debdab789ea8ad0228e2eb b/sql/hive/src/test/resources/golden/udf_struct-2-8ccdb20153debdab789ea8ad0228e2eb new file mode 100644 index 0000000000000..062cb1bc683b1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-2-8ccdb20153debdab789ea8ad0228e2eb @@ -0,0 +1 @@ +struct(col1, col2, col3, ...) - Creates a struct with the given field values diff --git a/sql/hive/src/test/resources/golden/udf_struct-3-71361a92b74c4d026ac7ae6e1e6746f1 b/sql/hive/src/test/resources/golden/udf_struct-3-71361a92b74c4d026ac7ae6e1e6746f1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_struct-4-b196b5d8849d52bbe5e2ee683f29e051 b/sql/hive/src/test/resources/golden/udf_struct-4-b196b5d8849d52bbe5e2ee683f29e051 new file mode 100644 index 0000000000000..ff1a28fa47f18 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-4-b196b5d8849d52bbe5e2ee683f29e051 @@ -0,0 +1 @@ +{"col1":1} {"col1":1,"col2":"a"} 1 a diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 4a64b5f5eb1b4..86535f8dd4f58 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.hive import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.types.{DataType, StructType} +import org.apache.spark.sql.catalyst.types.StructType +import org.apache.spark.sql.test.ExamplePointUDT class HiveMetastoreCatalogSuite extends FunSuite { @@ -29,4 +30,10 @@ class HiveMetastoreCatalogSuite extends FunSuite { val datatype = HiveMetastoreTypes.toDataType(metastr) assert(datatype.isInstanceOf[StructType]) } + + test("udt to metastore type conversion") { + val udt = new ExamplePointUDT + assert(HiveMetastoreTypes.toMetastoreType(udt) === + HiveMetastoreTypes.toMetastoreType(udt.sqlType)) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 18dc937dd2b27..5dbfb923139fa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql._ +import java.io.File + +import com.google.common.io.Files +import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.hive.test.TestHive /* Implicits */ @@ -91,4 +93,32 @@ class InsertIntoHiveTableSuite extends QueryTest { sql("DROP TABLE hiveTableWithMapValue") } + + test("SPARK-4203:random partition directory order") { + createTable[TestData]("tmp_table") + val tmpDir = Files.createTempDir() + sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table") + sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='2') SELECT 'blarr' FROM tmp_table") + sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='3') SELECT 'blarr' FROM tmp_table") + sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='4') SELECT 'blarr' FROM tmp_table") + def listFolders(path: File, acc: List[String]): List[List[String]] = { + val dir = path.listFiles() + val folders = dir.filter(_.isDirectory).toList + if (folders.isEmpty) { + List(acc.reverse) + } else { + folders.flatMap(x => listFolders(x, x.getName :: acc)) + } + } + val expected = List( + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=2"::Nil, + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=3"::Nil , + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil + ) + assert(listFolders(tmpDir,List()).sortBy(_.toString()) == expected.sortBy(_.toString)) + sql("DROP TABLE table_with_partition") + sql("DROP TABLE tmp_table") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index ffe1f0b90fcd0..684d22807c0c6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -17,23 +17,73 @@ package org.apache.spark.sql.hive.execution +import java.io.File +import java.util.{Locale, TimeZone} + +import org.scalatest.BeforeAndAfter + import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.spark.SparkException +import org.apache.spark.{SparkFiles, SparkException} import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.sql.{SQLConf, Row, SchemaRDD} case class TestData(a: Int, b: String) /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. */ -class HiveQuerySuite extends HiveComparisonTest { +class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { + private val originalTimeZone = TimeZone.getDefault + private val originalLocale = Locale.getDefault + + override def beforeAll() { + TestHive.cacheTables = true + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + } + + override def afterAll() { + TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + } + + createQueryTest("constant null testing", + """SELECT + |IF(FALSE, CAST(NULL AS STRING), CAST(1 AS STRING)) AS COL1, + |IF(TRUE, CAST(NULL AS STRING), CAST(1 AS STRING)) AS COL2, + |IF(FALSE, CAST(NULL AS INT), CAST(1 AS INT)) AS COL3, + |IF(TRUE, CAST(NULL AS INT), CAST(1 AS INT)) AS COL4, + |IF(FALSE, CAST(NULL AS DOUBLE), CAST(1 AS DOUBLE)) AS COL5, + |IF(TRUE, CAST(NULL AS DOUBLE), CAST(1 AS DOUBLE)) AS COL6, + |IF(FALSE, CAST(NULL AS BOOLEAN), CAST(1 AS BOOLEAN)) AS COL7, + |IF(TRUE, CAST(NULL AS BOOLEAN), CAST(1 AS BOOLEAN)) AS COL8, + |IF(FALSE, CAST(NULL AS BIGINT), CAST(1 AS BIGINT)) AS COL9, + |IF(TRUE, CAST(NULL AS BIGINT), CAST(1 AS BIGINT)) AS COL10, + |IF(FALSE, CAST(NULL AS FLOAT), CAST(1 AS FLOAT)) AS COL11, + |IF(TRUE, CAST(NULL AS FLOAT), CAST(1 AS FLOAT)) AS COL12, + |IF(FALSE, CAST(NULL AS SMALLINT), CAST(1 AS SMALLINT)) AS COL13, + |IF(TRUE, CAST(NULL AS SMALLINT), CAST(1 AS SMALLINT)) AS COL14, + |IF(FALSE, CAST(NULL AS TINYINT), CAST(1 AS TINYINT)) AS COL15, + |IF(TRUE, CAST(NULL AS TINYINT), CAST(1 AS TINYINT)) AS COL16, + |IF(FALSE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL17, + |IF(TRUE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL18, + |IF(FALSE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL19, + |IF(TRUE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL20, + |IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, + |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL22, + |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23, + |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL24 + |FROM src LIMIT 1""".stripMargin) + createQueryTest("constant array", """ |SELECT sort_array( @@ -569,7 +619,7 @@ class HiveQuerySuite extends HiveComparisonTest { |WITH serdeproperties('s1'='9') """.stripMargin) } - // Now only verify 0.12.0, and ignore other versions due to binary compatability + // Now only verify 0.12.0, and ignore other versions due to binary compatibility // current TestSerDe.jar is from 0.12.0 if (HiveShim.version == "0.12.0") { sql(s"ADD JAR $testJar") @@ -581,6 +631,17 @@ class HiveQuerySuite extends HiveComparisonTest { sql("DROP TABLE alter1") } + test("ADD FILE command") { + val testFile = TestHive.getHiveFile("data/files/v1.txt").getCanonicalFile + sql(s"ADD FILE $testFile") + + val checkAddFileRDD = sparkContext.parallelize(1 to 2, 1).mapPartitions { _ => + Iterator.single(new File(SparkFiles.get("v1.txt")).canRead) + } + + assert(checkAddFileRDD.first()) + } + case class LogEntry(filename: String, message: String) case class LogFile(name: String) @@ -756,7 +817,7 @@ class HiveQuerySuite extends HiveComparisonTest { }.toSet clear() - // "set" itself returns all config variables currently specified in SQLConf. + // "SET" itself returns all config variables currently specified in SQLConf. // TODO: Should we be listing the default here always? probably... assert(sql("SET").collect().size == 0) @@ -765,44 +826,19 @@ class HiveQuerySuite extends HiveComparisonTest { } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(sql("SET")) - } + assertResult(Set(testKey -> testVal))(collectResults(sql("SET"))) + assertResult(Set(testKey -> testVal))(collectResults(sql("SET -v"))) sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { collectResults(sql("SET")) } - - // "set key" - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey")) - } - - assertResult(Set(nonexistentKey -> "")) { - collectResults(sql(s"SET $nonexistentKey")) - } - - // Assert that sql() should have the same effects as sql() by repeating the above using sql(). - clear() - assert(sql("SET").collect().size == 0) - - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey=$testVal")) - } - - assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(sql("SET")) - } - - sql(s"SET ${testKey + testKey}=${testVal + testVal}") - assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(sql("SET")) + collectResults(sql("SET -v")) } + // "SET key" assertResult(Set(testKey -> testVal)) { collectResults(sql(s"SET $testKey")) } @@ -816,7 +852,7 @@ class HiveQuerySuite extends HiveComparisonTest { createQueryTest("select from thrift based table", "SELECT * from src_thrift") - + // Put tests that depend on specific Hive settings before these last two test, // since they modify /clear stuff. } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 2f3db95882093..54c0f017d4cb6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.{Row, SchemaRDD} + +import org.apache.spark.util.Utils class HiveTableScanSuite extends HiveComparisonTest { @@ -47,4 +50,23 @@ class HiveTableScanSuite extends HiveComparisonTest { TestHive.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() TestHive.sql("drop table tb") } + + test("Spark-4077: timestamp query for null value") { + TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null") + TestHive.sql( + """ + CREATE EXTERNAL TABLE timestamp_query_null (time TIMESTAMP,id INT) + ROW FORMAT DELIMITED + FIELDS TERMINATED BY ',' + LINES TERMINATED BY '\n' + """.stripMargin) + val location = + Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile() + + TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") + assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() + === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")),Row(null))) + TestHive.sql("DROP TABLE timestamp_query_null") + } + } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 4f96a327ee2c7..e9b1943ff8db7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -56,7 +56,7 @@ class SQLQuerySuite extends QueryTest { sql( """CREATE TABLE IF NOT EXISTS ctas4 AS | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect - // expect the string => integer for field key cause the table ctas4 already existed. + // do nothing cause the table ctas4 already existed. sql( """CREATE TABLE IF NOT EXISTS ctas4 AS | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect @@ -78,9 +78,14 @@ class SQLQuerySuite extends QueryTest { SELECT key, value FROM src ORDER BY key, value""").collect().toSeq) + intercept[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] { + sql( + """CREATE TABLE ctas4 AS + | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect + } checkAnswer( sql("SELECT key, value FROM ctas4 ORDER BY key, value"), - sql("SELECT CAST(key AS int) k, value FROM src ORDER BY k, value").collect().toSeq) + sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) checkExistence(sql("DESC EXTENDED ctas2"), true, "name:key", "type:string", "name:value", "ctas2", @@ -158,4 +163,9 @@ class SQLQuerySuite extends QueryTest { sql("SELECT case when ~1=-2 then 1 else 0 end FROM src"), sql("SELECT 1 FROM src").collect().toSeq) } + + test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { + checkAnswer(sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"), + sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq) + } } diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index afc252ac27987..8ba25f889d176 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -30,21 +30,24 @@ import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector} +import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} import org.apache.hadoop.mapred.InputFormat +import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.collection.JavaConversions._ import scala.language.implicitConversions +import org.apache.spark.sql.catalyst.types.DecimalType + /** * A compatibility layer for interacting with Hive version 0.12.0. */ private[hive] object HiveShim { val version = "0.12.0" - val metastoreDecimal = "decimal" def getTableDesc( serdeClass: Class[_ <: Deserializer], @@ -54,54 +57,74 @@ private[hive] object HiveShim { new TableDesc(serdeClass, inputFormatClass, outputFormatClass, properties) } - def getPrimitiveWritableConstantObjectInspector(value: String): ObjectInspector = + def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.STRING, new hadoopIo.Text(value)) + PrimitiveCategory.STRING, + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])) - def getPrimitiveWritableConstantObjectInspector(value: Int): ObjectInspector = + def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.INT, new hadoopIo.IntWritable(value)) + PrimitiveCategory.INT, + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])) - def getPrimitiveWritableConstantObjectInspector(value: Double): ObjectInspector = + def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DOUBLE, new hiveIo.DoubleWritable(value)) + PrimitiveCategory.DOUBLE, + if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double])) - def getPrimitiveWritableConstantObjectInspector(value: Boolean): ObjectInspector = + def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BOOLEAN, new hadoopIo.BooleanWritable(value)) + PrimitiveCategory.BOOLEAN, + if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean])) - def getPrimitiveWritableConstantObjectInspector(value: Long): ObjectInspector = + def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.LONG, new hadoopIo.LongWritable(value)) + PrimitiveCategory.LONG, + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])) - def getPrimitiveWritableConstantObjectInspector(value: Float): ObjectInspector = + def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.FLOAT, new hadoopIo.FloatWritable(value)) + PrimitiveCategory.FLOAT, + if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float])) - def getPrimitiveWritableConstantObjectInspector(value: Short): ObjectInspector = + def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.SHORT, new hiveIo.ShortWritable(value)) + PrimitiveCategory.SHORT, + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])) - def getPrimitiveWritableConstantObjectInspector(value: Byte): ObjectInspector = + def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BYTE, new hiveIo.ByteWritable(value)) + PrimitiveCategory.BYTE, + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])) - def getPrimitiveWritableConstantObjectInspector(value: Array[Byte]): ObjectInspector = + def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BINARY, new hadoopIo.BytesWritable(value)) + PrimitiveCategory.BINARY, + if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]])) - def getPrimitiveWritableConstantObjectInspector(value: java.sql.Date): ObjectInspector = + def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DATE, new hiveIo.DateWritable(value)) + PrimitiveCategory.DATE, + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])) - def getPrimitiveWritableConstantObjectInspector(value: java.sql.Timestamp): ObjectInspector = + def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.TIMESTAMP, new hiveIo.TimestampWritable(value)) - - def getPrimitiveWritableConstantObjectInspector(value: BigDecimal): ObjectInspector = + PrimitiveCategory.TIMESTAMP, + if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + }) + + def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.DECIMAL, - new hiveIo.HiveDecimalWritable(HiveShim.createDecimal(value.underlying()))) + if (value == null) { + null + } else { + new hiveIo.HiveDecimalWritable( + HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) + }) def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( @@ -149,6 +172,19 @@ private[hive] object HiveShim { def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = { tbl.setDataLocation(new Path(crtTbl.getLocation()).toUri()) } + + def decimalMetastoreString(decimalType: DecimalType): String = "decimal" + + def decimalTypeInfo(decimalType: DecimalType): TypeInfo = + TypeInfoFactory.decimalTypeInfo + + def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { + DecimalType.Unlimited + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + } } class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean) diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 42cd65b2518c9..e4aee57f0ad9f 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -29,15 +29,15 @@ import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition} import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory -import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer} -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector} import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.types.DecimalType +import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -47,11 +47,6 @@ import scala.language.implicitConversions */ private[hive] object HiveShim { val version = "0.13.1" - /* - * TODO: hive-0.13 support DECIMAL(precision, scale), DECIMAL in hive-0.12 is actually DECIMAL(38,unbounded) - * Full support of new decimal feature need to be fixed in seperate PR. - */ - val metastoreDecimal = "decimal\\((\\d+),(\\d+)\\)".r def getTableDesc( serdeClass: Class[_ <: Deserializer], @@ -61,54 +56,86 @@ private[hive] object HiveShim { new TableDesc(inputFormatClass, outputFormatClass, properties) } - def getPrimitiveWritableConstantObjectInspector(value: String): ObjectInspector = + def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.stringTypeInfo, new hadoopIo.Text(value)) + TypeInfoFactory.stringTypeInfo, + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])) - def getPrimitiveWritableConstantObjectInspector(value: Int): ObjectInspector = + def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.intTypeInfo, new hadoopIo.IntWritable(value)) + TypeInfoFactory.intTypeInfo, + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])) - def getPrimitiveWritableConstantObjectInspector(value: Double): ObjectInspector = + def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.doubleTypeInfo, new hiveIo.DoubleWritable(value)) + TypeInfoFactory.doubleTypeInfo, if (value == null) { + null + } else { + new hiveIo.DoubleWritable(value.asInstanceOf[Double]) + }) - def getPrimitiveWritableConstantObjectInspector(value: Boolean): ObjectInspector = + def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.booleanTypeInfo, new hadoopIo.BooleanWritable(value)) + TypeInfoFactory.booleanTypeInfo, if (value == null) { + null + } else { + new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) + }) - def getPrimitiveWritableConstantObjectInspector(value: Long): ObjectInspector = + def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.longTypeInfo, new hadoopIo.LongWritable(value)) + TypeInfoFactory.longTypeInfo, + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])) - def getPrimitiveWritableConstantObjectInspector(value: Float): ObjectInspector = + def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.floatTypeInfo, new hadoopIo.FloatWritable(value)) + TypeInfoFactory.floatTypeInfo, if (value == null) { + null + } else { + new hadoopIo.FloatWritable(value.asInstanceOf[Float]) + }) - def getPrimitiveWritableConstantObjectInspector(value: Short): ObjectInspector = + def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.shortTypeInfo, new hiveIo.ShortWritable(value)) + TypeInfoFactory.shortTypeInfo, + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])) - def getPrimitiveWritableConstantObjectInspector(value: Byte): ObjectInspector = + def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.byteTypeInfo, new hiveIo.ByteWritable(value)) + TypeInfoFactory.byteTypeInfo, + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])) - def getPrimitiveWritableConstantObjectInspector(value: Array[Byte]): ObjectInspector = + def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.binaryTypeInfo, new hadoopIo.BytesWritable(value)) + TypeInfoFactory.binaryTypeInfo, if (value == null) { + null + } else { + new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) + }) - def getPrimitiveWritableConstantObjectInspector(value: java.sql.Date): ObjectInspector = + def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.dateTypeInfo, new hiveIo.DateWritable(value)) + TypeInfoFactory.dateTypeInfo, + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])) - def getPrimitiveWritableConstantObjectInspector(value: java.sql.Timestamp): ObjectInspector = + def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.timestampTypeInfo, new hiveIo.TimestampWritable(value)) + TypeInfoFactory.timestampTypeInfo, if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + }) - def getPrimitiveWritableConstantObjectInspector(value: BigDecimal): ObjectInspector = + def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( TypeInfoFactory.decimalTypeInfo, - new hiveIo.HiveDecimalWritable(HiveShim.createDecimal(value.underlying()))) + if (value == null) { + null + } else { + // TODO precise, scale? + new hiveIo.HiveDecimalWritable( + HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) + }) def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( @@ -197,6 +224,30 @@ private[hive] object HiveShim { f.setDestTableId(w.destTableId) f } + + // Precision and scale to pass for unlimited decimals; these are the same as the precision and + // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) + private val UNLIMITED_DECIMAL_PRECISION = 38 + private val UNLIMITED_DECIMAL_SCALE = 18 + + def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { + case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" + case _ => s"decimal($UNLIMITED_DECIMAL_PRECISION,$UNLIMITED_DECIMAL_SCALE)" + } + + def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { + case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) + case _ => new DecimalTypeInfo(UNLIMITED_DECIMAL_PRECISION, UNLIMITED_DECIMAL_SCALE) + } + + def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { + val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] + DecimalType(info.precision(), info.scale()) + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } } /* diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 23d6d1c5e50fa..54b219711efb9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -436,10 +436,10 @@ class StreamingContext private[streaming] ( /** * Start the execution of the streams. + * + * @throws SparkException if the context has already been started or stopped. */ def start(): Unit = synchronized { - // Throw exception if the context has already been started once - // or if a stopped context is being started again if (state == Started) { throw new SparkException("StreamingContext has already been started") } @@ -472,8 +472,10 @@ class StreamingContext private[streaming] ( /** * Stop the execution of the streams immediately (does not wait for all received data * to be processed). - * @param stopSparkContext Stop the associated SparkContext or not * + * @param stopSparkContext if true, stops the associated SparkContext. The underlying SparkContext + * will be stopped regardless of whether this StreamingContext has been + * started. */ def stop(stopSparkContext: Boolean = true): Unit = synchronized { stop(stopSparkContext, false) @@ -482,25 +484,27 @@ class StreamingContext private[streaming] ( /** * Stop the execution of the streams, with option of ensuring all received data * has been processed. - * @param stopSparkContext Stop the associated SparkContext or not - * @param stopGracefully Stop gracefully by waiting for the processing of all + * + * @param stopSparkContext if true, stops the associated SparkContext. The underlying SparkContext + * will be stopped regardless of whether this StreamingContext has been + * started. + * @param stopGracefully if true, stops gracefully by waiting for the processing of all * received data to be completed */ def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized { - // Warn (but not fail) if context is stopped twice, - // or context is stopped before starting - if (state == Initialized) { - logWarning("StreamingContext has not been started yet") - return + state match { + case Initialized => logWarning("StreamingContext has not been started yet") + case Stopped => logWarning("StreamingContext has already been stopped") + case Started => + scheduler.stop(stopGracefully) + logInfo("StreamingContext stopped successfully") + waiter.notifyStop() } - if (state == Stopped) { - logWarning("StreamingContext has already been stopped") - return - } // no need to throw an exception as its okay to stop twice - scheduler.stop(stopGracefully) - logInfo("StreamingContext stopped successfully") - waiter.notifyStop() + // Even if the streaming context has not been started, we still need to stop the SparkContext. + // Even if we have already stopped, we still need to attempt to stop the SparkContext because + // a user might stop(stopSparkContext = false) and then call stop(stopSparkContext = true). if (stopSparkContext) sc.stop() + // The state should always be Stopped after calling `stop()`, even if we haven't started yet: state = Stopped } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 8152b7542ac57..55d6cf6a783ea 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -120,14 +120,15 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas /** Generate one RDD from an array of files */ private def filesToRDD(files: Seq[String]): RDD[(K, V)] = { - val fileRDDs = files.map(file => context.sparkContext.newAPIHadoopFile[K, V, F](file)) - files.zip(fileRDDs).foreach { case (file, rdd) => { + val fileRDDs = files.map(file =>{ + val rdd = context.sparkContext.newAPIHadoopFile[K, V, F](file) if (rdd.partitions.size == 0) { logError("File " + file + " has no data in it. Spark Streaming can only ingest " + "files that have been \"moved\" to the directory assigned to the file stream. " + "Refer to the streaming programming guide for more details.") } - }} + rdd + }) new UnionRDD(context.sparkContext, fileRDDs) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 391e40924f38a..3e67161363e50 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -17,13 +17,13 @@ package org.apache.spark.streaming.dstream -import scala.collection.mutable.HashMap import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} -import org.apache.spark.storage.BlockId +import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming._ -import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD +import org.apache.spark.streaming.receiver.{Receiver, WriteAheadLogBasedStoreResult} import org.apache.spark.streaming.scheduler.ReceivedBlockInfo /** @@ -39,9 +39,6 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { - /** Keeps all received blocks information */ - private lazy val receivedBlockInfo = new HashMap[Time, Array[ReceivedBlockInfo]] - /** This is an unique identifier for the network input stream. */ val id = ssc.getNewReceiverStreamId() @@ -57,24 +54,45 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont def stop() {} - /** Ask ReceiverInputTracker for received data blocks and generates RDDs with them. */ + /** + * Generates RDDs with blocks received by the receiver of this stream. */ override def compute(validTime: Time): Option[RDD[T]] = { - // If this is called for any time before the start time of the context, - // then this returns an empty RDD. This may happen when recovering from a - // master failure - if (validTime >= graph.startTime) { - val blockInfo = ssc.scheduler.receiverTracker.getReceivedBlockInfo(id) - receivedBlockInfo(validTime) = blockInfo - val blockIds = blockInfo.map(_.blockId.asInstanceOf[BlockId]) - Some(new BlockRDD[T](ssc.sc, blockIds)) - } else { - Some(new BlockRDD[T](ssc.sc, Array[BlockId]())) - } - } + val blockRDD = { - /** Get information on received blocks. */ - private[streaming] def getReceivedBlockInfo(time: Time) = { - receivedBlockInfo.get(time).getOrElse(Array.empty[ReceivedBlockInfo]) + if (validTime < graph.startTime) { + // If this is called for any time before the start time of the context, + // then this returns an empty RDD. This may happen when recovering from a + // driver failure without any write ahead log to recover pre-failure data. + new BlockRDD[T](ssc.sc, Array.empty) + } else { + // Otherwise, ask the tracker for all the blocks that have been allocated to this stream + // for this batch + val blockInfos = + ssc.scheduler.receiverTracker.getBlocksOfBatch(validTime).get(id).getOrElse(Seq.empty) + val blockStoreResults = blockInfos.map { _.blockStoreResult } + val blockIds = blockStoreResults.map { _.blockId.asInstanceOf[BlockId] }.toArray + + // Check whether all the results are of the same type + val resultTypes = blockStoreResults.map { _.getClass }.distinct + if (resultTypes.size > 1) { + logWarning("Multiple result types in block information, WAL information will be ignored.") + } + + // If all the results are of type WriteAheadLogBasedStoreResult, then create + // WriteAheadLogBackedBlockRDD else create simple BlockRDD. + if (resultTypes.size == 1 && resultTypes.head == classOf[WriteAheadLogBasedStoreResult]) { + val logSegments = blockStoreResults.map { + _.asInstanceOf[WriteAheadLogBasedStoreResult].segment + }.toArray + // Since storeInBlockManager = false, the storage level does not matter. + new WriteAheadLogBackedBlockRDD[T](ssc.sparkContext, + blockIds, logSegments, storeInBlockManager = true, StorageLevel.MEMORY_ONLY_SER) + } else { + new BlockRDD[T](ssc.sc, blockIds) + } + } + } + Some(blockRDD) } /** @@ -85,10 +103,6 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont */ private[streaming] override def clearMetadata(time: Time) { super.clearMetadata(time) - val oldReceivedBlocks = receivedBlockInfo.filter(_._1 <= (time - rememberDuration)) - receivedBlockInfo --= oldReceivedBlocks.keys - logDebug("Cleared " + oldReceivedBlocks.size + " RDDs that were older than " + - (time - rememberDuration) + ": " + oldReceivedBlocks.keys.mkString(", ")) + ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration) } } - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala new file mode 100644 index 0000000000000..dd1e96334952f --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.rdd + +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark._ +import org.apache.spark.rdd.BlockRDD +import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.streaming.util.{HdfsUtils, WriteAheadLogFileSegment, WriteAheadLogRandomReader} + +/** + * Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]]. + * It contains information about the id of the blocks having this partition's data and + * the segment of the write ahead log that backs the partition. + * @param index index of the partition + * @param blockId id of the block having the partition data + * @param segment segment of the write ahead log having the partition data + */ +private[streaming] +class WriteAheadLogBackedBlockRDDPartition( + val index: Int, + val blockId: BlockId, + val segment: WriteAheadLogFileSegment) + extends Partition + + +/** + * This class represents a special case of the BlockRDD where the data blocks in + * the block manager are also backed by segments in write ahead logs. For reading + * the data, this RDD first looks up the blocks by their ids in the block manager. + * If it does not find them, it looks up the corresponding file segment. + * + * @param sc SparkContext + * @param blockIds Ids of the blocks that contains this RDD's data + * @param segments Segments in write ahead logs that contain this RDD's data + * @param storeInBlockManager Whether to store in the block manager after reading from the segment + * @param storageLevel storage level to store when storing in block manager + * (applicable when storeInBlockManager = true) + */ +private[streaming] +class WriteAheadLogBackedBlockRDD[T: ClassTag]( + @transient sc: SparkContext, + @transient blockIds: Array[BlockId], + @transient segments: Array[WriteAheadLogFileSegment], + storeInBlockManager: Boolean, + storageLevel: StorageLevel) + extends BlockRDD[T](sc, blockIds) { + + require( + blockIds.length == segments.length, + s"Number of block ids (${blockIds.length}) must be " + + s"the same as number of segments (${segments.length}})!") + + // Hadoop configuration is not serializable, so broadcast it as a serializable. + @transient private val hadoopConfig = sc.hadoopConfiguration + private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig) + + override def getPartitions: Array[Partition] = { + assertValid() + Array.tabulate(blockIds.size) { i => + new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), segments(i)) + } + } + + /** + * Gets the partition data by getting the corresponding block from the block manager. + * If the block does not exist, then the data is read from the corresponding segment + * in write ahead log files. + */ + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + assertValid() + val hadoopConf = broadcastedHadoopConf.value + val blockManager = SparkEnv.get.blockManager + val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] + val blockId = partition.blockId + blockManager.get(blockId) match { + case Some(block) => // Data is in Block Manager + val iterator = block.data.asInstanceOf[Iterator[T]] + logDebug(s"Read partition data of $this from block manager, block $blockId") + iterator + case None => // Data not found in Block Manager, grab it from write ahead log file + val reader = new WriteAheadLogRandomReader(partition.segment.path, hadoopConf) + val dataRead = reader.read(partition.segment) + reader.close() + logInfo(s"Read partition data of $this from write ahead log, segment ${partition.segment}") + if (storeInBlockManager) { + blockManager.putBytes(blockId, dataRead, storageLevel) + logDebug(s"Stored partition data of $this into block manager with level $storageLevel") + dataRead.rewind() + } + blockManager.dataDeserialize(blockId, dataRead).asInstanceOf[Iterator[T]] + } + } + + /** + * Get the preferred location of the partition. This returns the locations of the block + * if it is present in the block manager, else it returns the location of the + * corresponding segment in HDFS. + */ + override def getPreferredLocations(split: Partition): Seq[String] = { + val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] + val blockLocations = getBlockIdLocations().get(partition.blockId) + def segmentLocations = HdfsUtils.getFileSegmentLocations( + partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig) + blockLocations.getOrElse(segmentLocations) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlock.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlock.scala new file mode 100644 index 0000000000000..47968afef2dbf --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlock.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.receiver + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer +import scala.language.existentials + +/** Trait representing a received block */ +private[streaming] sealed trait ReceivedBlock + +/** class representing a block received as an ArrayBuffer */ +private[streaming] case class ArrayBufferBlock(arrayBuffer: ArrayBuffer[_]) extends ReceivedBlock + +/** class representing a block received as an Iterator */ +private[streaming] case class IteratorBlock(iterator: Iterator[_]) extends ReceivedBlock + +/** class representing a block received as an ByteBuffer */ +private[streaming] case class ByteBufferBlock(byteBuffer: ByteBuffer) extends ReceivedBlock diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala new file mode 100644 index 0000000000000..fdf995320beb4 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.receiver + +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ +import scala.language.{existentials, postfixOps} + +import WriteAheadLogBasedBlockHandler._ +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.storage._ +import org.apache.spark.streaming.util.{Clock, SystemClock, WriteAheadLogFileSegment, WriteAheadLogManager} +import org.apache.spark.util.Utils + +/** Trait that represents the metadata related to storage of blocks */ +private[streaming] trait ReceivedBlockStoreResult { + def blockId: StreamBlockId // Any implementation of this trait will store a block id +} + +/** Trait that represents a class that handles the storage of blocks received by receiver */ +private[streaming] trait ReceivedBlockHandler { + + /** Store a received block with the given block id and return related metadata */ + def storeBlock(blockId: StreamBlockId, receivedBlock: ReceivedBlock): ReceivedBlockStoreResult + + /** Cleanup old blocks older than the given threshold time */ + def cleanupOldBlock(threshTime: Long) +} + + +/** + * Implementation of [[org.apache.spark.streaming.receiver.ReceivedBlockStoreResult]] + * that stores the metadata related to storage of blocks using + * [[org.apache.spark.streaming.receiver.BlockManagerBasedBlockHandler]] + */ +private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockId) + extends ReceivedBlockStoreResult + + +/** + * Implementation of a [[org.apache.spark.streaming.receiver.ReceivedBlockHandler]] which + * stores the received blocks into a block manager with the specified storage level. + */ +private[streaming] class BlockManagerBasedBlockHandler( + blockManager: BlockManager, storageLevel: StorageLevel) + extends ReceivedBlockHandler with Logging { + + def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + val putResult: Seq[(BlockId, BlockStatus)] = block match { + case ArrayBufferBlock(arrayBuffer) => + blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, tellMaster = true) + case IteratorBlock(iterator) => + blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true) + case ByteBufferBlock(byteBuffer) => + blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true) + case o => + throw new SparkException( + s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}") + } + if (!putResult.map { _._1 }.contains(blockId)) { + throw new SparkException( + s"Could not store $blockId to block manager with storage level $storageLevel") + } + BlockManagerBasedStoreResult(blockId) + } + + def cleanupOldBlock(threshTime: Long) { + // this is not used as blocks inserted into the BlockManager are cleared by DStream's clearing + // of BlockRDDs. + } +} + + +/** + * Implementation of [[org.apache.spark.streaming.receiver.ReceivedBlockStoreResult]] + * that stores the metadata related to storage of blocks using + * [[org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler]] + */ +private[streaming] case class WriteAheadLogBasedStoreResult( + blockId: StreamBlockId, + segment: WriteAheadLogFileSegment + ) extends ReceivedBlockStoreResult + + +/** + * Implementation of a [[org.apache.spark.streaming.receiver.ReceivedBlockHandler]] which + * stores the received blocks in both, a write ahead log and a block manager. + */ +private[streaming] class WriteAheadLogBasedBlockHandler( + blockManager: BlockManager, + streamId: Int, + storageLevel: StorageLevel, + conf: SparkConf, + hadoopConf: Configuration, + checkpointDir: String, + clock: Clock = new SystemClock + ) extends ReceivedBlockHandler with Logging { + + private val blockStoreTimeout = conf.getInt( + "spark.streaming.receiver.blockStoreTimeout", 30).seconds + private val rollingInterval = conf.getInt( + "spark.streaming.receiver.writeAheadLog.rollingInterval", 60) + private val maxFailures = conf.getInt( + "spark.streaming.receiver.writeAheadLog.maxFailures", 3) + + // Manages rolling log files + private val logManager = new WriteAheadLogManager( + checkpointDirToLogDir(checkpointDir, streamId), + hadoopConf, rollingInterval, maxFailures, + callerName = this.getClass.getSimpleName, + clock = clock + ) + + // For processing futures used in parallel block storing into block manager and write ahead log + // # threads = 2, so that both writing to BM and WAL can proceed in parallel + implicit private val executionContext = ExecutionContext.fromExecutorService( + Utils.newDaemonFixedThreadPool(2, this.getClass.getSimpleName)) + + /** + * This implementation stores the block into the block manager as well as a write ahead log. + * It does this in parallel, using Scala Futures, and returns only after the block has + * been stored in both places. + */ + def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + + // Serialize the block so that it can be inserted into both + val serializedBlock = block match { + case ArrayBufferBlock(arrayBuffer) => + blockManager.dataSerialize(blockId, arrayBuffer.iterator) + case IteratorBlock(iterator) => + blockManager.dataSerialize(blockId, iterator) + case ByteBufferBlock(byteBuffer) => + byteBuffer + case _ => + throw new Exception(s"Could not push $blockId to block manager, unexpected block type") + } + + // Store the block in block manager + val storeInBlockManagerFuture = Future { + val putResult = + blockManager.putBytes(blockId, serializedBlock, storageLevel, tellMaster = true) + if (!putResult.map { _._1 }.contains(blockId)) { + throw new SparkException( + s"Could not store $blockId to block manager with storage level $storageLevel") + } + } + + // Store the block in write ahead log + val storeInWriteAheadLogFuture = Future { + logManager.writeToLog(serializedBlock) + } + + // Combine the futures, wait for both to complete, and return the write ahead log segment + val combinedFuture = for { + _ <- storeInBlockManagerFuture + fileSegment <- storeInWriteAheadLogFuture + } yield fileSegment + val segment = Await.result(combinedFuture, blockStoreTimeout) + WriteAheadLogBasedStoreResult(blockId, segment) + } + + def cleanupOldBlock(threshTime: Long) { + logManager.cleanupOldLogs(threshTime) + } + + def stop() { + logManager.stop() + } +} + +private[streaming] object WriteAheadLogBasedBlockHandler { + def checkpointDirToLogDir(checkpointDir: String, streamId: Int): String = { + new Path(checkpointDir, new Path("receivedData", streamId.toString)).toString + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 53a3e6200e340..5360412330d37 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -25,16 +25,13 @@ import scala.concurrent.Await import akka.actor.{Actor, Props} import akka.pattern.ask - import com.google.common.base.Throwables - -import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.hadoop.conf.Configuration +import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.storage.StreamBlockId -import org.apache.spark.streaming.scheduler.DeregisterReceiver -import org.apache.spark.streaming.scheduler.AddBlock -import org.apache.spark.streaming.scheduler.RegisterReceiver +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.util.WriteAheadLogFileSegment +import org.apache.spark.util.{AkkaUtils, Utils} /** * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]] @@ -44,12 +41,26 @@ import org.apache.spark.streaming.scheduler.RegisterReceiver */ private[streaming] class ReceiverSupervisorImpl( receiver: Receiver[_], - env: SparkEnv + env: SparkEnv, + hadoopConf: Configuration, + checkpointDirOption: Option[String] ) extends ReceiverSupervisor(receiver, env.conf) with Logging { - private val blockManager = env.blockManager + private val receivedBlockHandler: ReceivedBlockHandler = { + if (env.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + if (checkpointDirOption.isEmpty) { + throw new SparkException( + "Cannot enable receiver write-ahead log without checkpoint directory set. " + + "Please use streamingContext.checkpoint() to set the checkpoint directory. " + + "See documentation for more details.") + } + new WriteAheadLogBasedBlockHandler(env.blockManager, receiver.streamId, + receiver.storageLevel, env.conf, hadoopConf, checkpointDirOption.get) + } else { + new BlockManagerBasedBlockHandler(env.blockManager, receiver.storageLevel) + } + } - private val storageLevel = receiver.storageLevel /** Remote Akka actor for the ReceiverTracker */ private val trackerActor = { @@ -105,47 +116,50 @@ private[streaming] class ReceiverSupervisorImpl( /** Store an ArrayBuffer of received data as a data block into Spark's memory. */ def pushArrayBuffer( arrayBuffer: ArrayBuffer[_], - optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] + metadataOption: Option[Any], + blockIdOption: Option[StreamBlockId] ) { - val blockId = optionalBlockId.getOrElse(nextBlockId) - val time = System.currentTimeMillis - blockManager.putArray(blockId, arrayBuffer.toArray[Any], storageLevel, tellMaster = true) - logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - reportPushedBlock(blockId, arrayBuffer.size, optionalMetadata) + pushAndReportBlock(ArrayBufferBlock(arrayBuffer), metadataOption, blockIdOption) } /** Store a iterator of received data as a data block into Spark's memory. */ def pushIterator( iterator: Iterator[_], - optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] + metadataOption: Option[Any], + blockIdOption: Option[StreamBlockId] ) { - val blockId = optionalBlockId.getOrElse(nextBlockId) - val time = System.currentTimeMillis - blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true) - logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - reportPushedBlock(blockId, -1, optionalMetadata) + pushAndReportBlock(IteratorBlock(iterator), metadataOption, blockIdOption) } /** Store the bytes of received data as a data block into Spark's memory. */ def pushBytes( bytes: ByteBuffer, - optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] + metadataOption: Option[Any], + blockIdOption: Option[StreamBlockId] ) { - val blockId = optionalBlockId.getOrElse(nextBlockId) - val time = System.currentTimeMillis - blockManager.putBytes(blockId, bytes, storageLevel, tellMaster = true) - logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - reportPushedBlock(blockId, -1, optionalMetadata) + pushAndReportBlock(ByteBufferBlock(bytes), metadataOption, blockIdOption) } - /** Report pushed block */ - def reportPushedBlock(blockId: StreamBlockId, numRecords: Long, optionalMetadata: Option[Any]) { - val blockInfo = ReceivedBlockInfo(streamId, blockId, numRecords, optionalMetadata.orNull) - trackerActor ! AddBlock(blockInfo) - logDebug("Reported block " + blockId) + /** Store block and report it to driver */ + def pushAndReportBlock( + receivedBlock: ReceivedBlock, + metadataOption: Option[Any], + blockIdOption: Option[StreamBlockId] + ) { + val blockId = blockIdOption.getOrElse(nextBlockId) + val numRecords = receivedBlock match { + case ArrayBufferBlock(arrayBuffer) => arrayBuffer.size + case _ => -1 + } + + val time = System.currentTimeMillis + val blockStoreResult = receivedBlockHandler.storeBlock(blockId, receivedBlock) + logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") + + val blockInfo = ReceivedBlockInfo(streamId, numRecords, blockStoreResult) + val future = trackerActor.ask(AddBlock(blockInfo))(askTimeout) + Await.result(future, askTimeout) + logDebug(s"Reported block $blockId") } /** Report error to the receiver tracker */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index a68aecb881117..92dc113f397ca 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.scheduler -import org.apache.spark.streaming.Time import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.streaming.Time /** * :: DeveloperApi :: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 7d73ada12d107..39b66e1130768 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -112,7 +112,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // Wait until all the received blocks in the network input tracker has // been consumed by network input DStreams, and jobs have been generated with them logInfo("Waiting for all received blocks to be consumed for job generation") - while(!hasTimedOut && jobScheduler.receiverTracker.hasMoreReceivedBlockIds) { + while(!hasTimedOut && jobScheduler.receiverTracker.hasUnallocatedBlocks) { Thread.sleep(pollTime) } logInfo("Waited for all received blocks to be consumed for job generation") @@ -217,14 +217,18 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Generate jobs and perform checkpoint for the given `time`. */ private def generateJobs(time: Time) { - Try(graph.generateJobs(time)) match { + // Set the SparkEnv in this thread, so that job generation code can access the environment + // Example: BlockRDDs are created in this thread, and it needs to access BlockManager + // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed. + SparkEnv.set(ssc.env) + Try { + jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch + graph.generateJobs(time) // generate jobs using allocated block + } match { case Success(jobs) => - val receivedBlockInfo = graph.getReceiverInputStreams.map { stream => - val streamId = stream.id - val receivedBlockInfo = stream.getReceivedBlockInfo(time) - (streamId, receivedBlockInfo) - }.toMap - jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfo)) + val receivedBlockInfos = + jobScheduler.receiverTracker.getBlocksOfBatch(time).mapValues { _.toArray } + jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfos)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } @@ -234,6 +238,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) + jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration) // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index a69d74362173e..8c15a75b1b0e0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -17,7 +17,8 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{ArrayBuffer, HashSet} +import scala.collection.mutable.HashSet + import org.apache.spark.streaming.Time /** Class representing a set of Jobs diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala new file mode 100644 index 0000000000000..94beb590f52d6 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.streaming.receiver.ReceivedBlockStoreResult + +/** Information about blocks received by the receiver */ +private[streaming] case class ReceivedBlockInfo( + streamId: Int, + numRecords: Long, + blockStoreResult: ReceivedBlockStoreResult + ) + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala new file mode 100644 index 0000000000000..5f5e1909908d5 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.language.implicitConversions + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.util.{Clock, WriteAheadLogManager} +import org.apache.spark.util.Utils + +/** Trait representing any event in the ReceivedBlockTracker that updates its state. */ +private[streaming] sealed trait ReceivedBlockTrackerLogEvent + +private[streaming] case class BlockAdditionEvent(receivedBlockInfo: ReceivedBlockInfo) + extends ReceivedBlockTrackerLogEvent +private[streaming] case class BatchAllocationEvent(time: Time, allocatedBlocks: AllocatedBlocks) + extends ReceivedBlockTrackerLogEvent +private[streaming] case class BatchCleanupEvent(times: Seq[Time]) + extends ReceivedBlockTrackerLogEvent + + +/** Class representing the blocks of all the streams allocated to a batch */ +private[streaming] +case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) { + def getBlocksOfStream(streamId: Int): Seq[ReceivedBlockInfo] = { + streamIdToAllocatedBlocks.get(streamId).getOrElse(Seq.empty) + } +} + +/** + * Class that keep track of all the received blocks, and allocate them to batches + * when required. All actions taken by this class can be saved to a write ahead log + * (if a checkpoint directory has been provided), so that the state of the tracker + * (received blocks and block-to-batch allocations) can be recovered after driver failure. + * + * Note that when any instance of this class is created with a checkpoint directory, + * it will try reading events from logs in the directory. + */ +private[streaming] class ReceivedBlockTracker( + conf: SparkConf, + hadoopConf: Configuration, + streamIds: Seq[Int], + clock: Clock, + checkpointDirOption: Option[String]) + extends Logging { + + private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo] + + private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue] + private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks] + + private val logManagerRollingIntervalSecs = conf.getInt( + "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60) + private val logManagerOption = checkpointDirOption.map { checkpointDir => + new WriteAheadLogManager( + ReceivedBlockTracker.checkpointDirToLogDir(checkpointDir), + hadoopConf, + rollingIntervalSecs = logManagerRollingIntervalSecs, + callerName = "ReceivedBlockHandlerMaster", + clock = clock + ) + } + + private var lastAllocatedBatchTime: Time = null + + // Recover block information from write ahead logs + recoverFromWriteAheadLogs() + + /** Add received block. This event will get written to the write ahead log (if enabled). */ + def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized { + try { + writeToLog(BlockAdditionEvent(receivedBlockInfo)) + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + logDebug(s"Stream ${receivedBlockInfo.streamId} received " + + s"block ${receivedBlockInfo.blockStoreResult.blockId}") + true + } catch { + case e: Exception => + logError(s"Error adding block $receivedBlockInfo", e) + false + } + } + + /** + * Allocate all unallocated blocks to the given batch. + * This event will get written to the write ahead log (if enabled). + */ + def allocateBlocksToBatch(batchTime: Time): Unit = synchronized { + if (lastAllocatedBatchTime == null || batchTime > lastAllocatedBatchTime) { + val streamIdToBlocks = streamIds.map { streamId => + (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) + }.toMap + val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) + writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) + timeToAllocatedBlocks(batchTime) = allocatedBlocks + lastAllocatedBatchTime = batchTime + allocatedBlocks + } else { + throw new SparkException(s"Unexpected allocation of blocks, " + + s"last batch = $lastAllocatedBatchTime, batch time to allocate = $batchTime ") + } + } + + /** Get the blocks allocated to the given batch. */ + def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = synchronized { + timeToAllocatedBlocks.get(batchTime).map { _.streamIdToAllocatedBlocks }.getOrElse(Map.empty) + } + + /** Get the blocks allocated to the given batch and stream. */ + def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { + synchronized { + timeToAllocatedBlocks.get(batchTime).map { + _.getBlocksOfStream(streamId) + }.getOrElse(Seq.empty) + } + } + + /** Check if any blocks are left to be allocated to batches. */ + def hasUnallocatedReceivedBlocks: Boolean = synchronized { + !streamIdToUnallocatedBlockQueues.values.forall(_.isEmpty) + } + + /** + * Get blocks that have been added but not yet allocated to any batch. This method + * is primarily used for testing. + */ + def getUnallocatedBlocks(streamId: Int): Seq[ReceivedBlockInfo] = synchronized { + getReceivedBlockQueue(streamId).toSeq + } + + /** Clean up block information of old batches. */ + def cleanupOldBatches(cleanupThreshTime: Time): Unit = synchronized { + assert(cleanupThreshTime.milliseconds < clock.currentTime()) + val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq + logInfo("Deleting batches " + timesToCleanup) + writeToLog(BatchCleanupEvent(timesToCleanup)) + timeToAllocatedBlocks --= timesToCleanup + logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds)) + log + } + + /** Stop the block tracker. */ + def stop() { + logManagerOption.foreach { _.stop() } + } + + /** + * Recover all the tracker actions from the write ahead logs to recover the state (unallocated + * and allocated block info) prior to failure. + */ + private def recoverFromWriteAheadLogs(): Unit = synchronized { + // Insert the recovered block information + def insertAddedBlock(receivedBlockInfo: ReceivedBlockInfo) { + logTrace(s"Recovery: Inserting added block $receivedBlockInfo") + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + } + + // Insert the recovered block-to-batch allocations and clear the queue of received blocks + // (when the blocks were originally allocated to the batch, the queue must have been cleared). + def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) { + logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + + s"${allocatedBlocks.streamIdToAllocatedBlocks}") + streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } + lastAllocatedBatchTime = batchTime + timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + } + + // Cleanup the batch allocations + def cleanupBatches(batchTimes: Seq[Time]) { + logTrace(s"Recovery: Cleaning up batches $batchTimes") + timeToAllocatedBlocks --= batchTimes + } + + logManagerOption.foreach { logManager => + logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") + logManager.readFromLog().foreach { byteBuffer => + logTrace("Recovering record " + byteBuffer) + Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match { + case BlockAdditionEvent(receivedBlockInfo) => + insertAddedBlock(receivedBlockInfo) + case BatchAllocationEvent(time, allocatedBlocks) => + insertAllocatedBatch(time, allocatedBlocks) + case BatchCleanupEvent(batchTimes) => + cleanupBatches(batchTimes) + } + } + } + } + + /** Write an update to the tracker to the write ahead log */ + private def writeToLog(record: ReceivedBlockTrackerLogEvent) { + logDebug(s"Writing to log $record") + logManagerOption.foreach { logManager => + logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record))) + } + } + + /** Get the queue of received blocks belonging to a particular stream */ + private def getReceivedBlockQueue(streamId: Int): ReceivedBlockQueue = { + streamIdToUnallocatedBlockQueues.getOrElseUpdate(streamId, new ReceivedBlockQueue) + } +} + +private[streaming] object ReceivedBlockTracker { + def checkpointDirToLogDir(checkpointDir: String): String = { + new Path(checkpointDir, "receivedBlockMetadata").toString + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 7149dbc12a365..1c3984d968d20 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,24 +17,16 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{HashMap, SynchronizedMap, SynchronizedQueue} + +import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.language.existentials import akka.actor._ -import org.apache.spark.{Logging, SparkEnv, SparkException} + +import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} import org.apache.spark.SparkContext._ -import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver} -import org.apache.spark.util.AkkaUtils - -/** Information about blocks received by the receiver */ -private[streaming] case class ReceivedBlockInfo( - streamId: Int, - blockId: StreamBlockId, - numRecords: Long, - metadata: Any - ) /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -57,23 +49,28 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, err * This class manages the execution of the receivers of NetworkInputDStreams. Instance of * this class must be created after all input streams have been added and StreamingContext.start() * has been called because it needs the final set of input streams at the time of instantiation. + * + * @param skipReceiverLaunch Do not launch the receiver. This is useful for testing. */ private[streaming] -class ReceiverTracker(ssc: StreamingContext) extends Logging { +class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false) extends Logging { - val receiverInputStreams = ssc.graph.getReceiverInputStreams() - val receiverInputStreamMap = Map(receiverInputStreams.map(x => (x.id, x)): _*) - val receiverExecutor = new ReceiverLauncher() - val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] - val receivedBlockInfo = new HashMap[Int, SynchronizedQueue[ReceivedBlockInfo]] - with SynchronizedMap[Int, SynchronizedQueue[ReceivedBlockInfo]] - val timeout = AkkaUtils.askTimeout(ssc.conf) - val listenerBus = ssc.scheduler.listenerBus + private val receiverInputStreams = ssc.graph.getReceiverInputStreams() + private val receiverInputStreamIds = receiverInputStreams.map { _.id } + private val receiverExecutor = new ReceiverLauncher() + private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] + private val receivedBlockTracker = new ReceivedBlockTracker( + ssc.sparkContext.conf, + ssc.sparkContext.hadoopConfiguration, + receiverInputStreamIds, + ssc.scheduler.clock, + Option(ssc.checkpointDir) + ) + private val listenerBus = ssc.scheduler.listenerBus // actor is created when generator starts. // This not being null means the tracker has been started and not stopped - var actor: ActorRef = null - var currentTime: Time = null + private var actor: ActorRef = null /** Start the actor and receiver execution thread. */ def start() = synchronized { @@ -84,7 +81,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { if (!receiverInputStreams.isEmpty) { actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor), "ReceiverTracker") - receiverExecutor.start() + if (!skipReceiverLaunch) receiverExecutor.start() logInfo("ReceiverTracker started") } } @@ -93,45 +90,59 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { def stop() = synchronized { if (!receiverInputStreams.isEmpty && actor != null) { // First, stop the receivers - receiverExecutor.stop() + if (!skipReceiverLaunch) receiverExecutor.stop() // Finally, stop the actor ssc.env.actorSystem.stop(actor) actor = null + receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") } } - /** Return all the blocks received from a receiver. */ - def getReceivedBlockInfo(streamId: Int): Array[ReceivedBlockInfo] = { - val receivedBlockInfo = getReceivedBlockInfoQueue(streamId).dequeueAll(x => true) - logInfo("Stream " + streamId + " received " + receivedBlockInfo.size + " blocks") - receivedBlockInfo.toArray + /** Allocate all unallocated blocks to the given batch. */ + def allocateBlocksToBatch(batchTime: Time): Unit = { + if (receiverInputStreams.nonEmpty) { + receivedBlockTracker.allocateBlocksToBatch(batchTime) + } + } + + /** Get the blocks for the given batch and all input streams. */ + def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = { + receivedBlockTracker.getBlocksOfBatch(batchTime) } - private def getReceivedBlockInfoQueue(streamId: Int) = { - receivedBlockInfo.getOrElseUpdate(streamId, new SynchronizedQueue[ReceivedBlockInfo]) + /** Get the blocks allocated to the given batch and stream. */ + def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { + synchronized { + receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) + } + } + + /** Clean up metadata older than the given threshold time */ + def cleanupOldMetadata(cleanupThreshTime: Time) { + receivedBlockTracker.cleanupOldBatches(cleanupThreshTime) } /** Register a receiver */ - def registerReceiver( + private def registerReceiver( streamId: Int, typ: String, host: String, receiverActor: ActorRef, sender: ActorRef ) { - if (!receiverInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected id " + streamId) + if (!receiverInputStreamIds.contains(streamId)) { + throw new SparkException("Register received for unexpected id " + streamId) } receiverInfo(streamId) = ReceiverInfo( streamId, s"${typ}-${streamId}", receiverActor, true, host) - ssc.scheduler.listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address) } /** Deregister a receiver */ - def deregisterReceiver(streamId: Int, message: String, error: String) { + private def deregisterReceiver(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => oldInfo.copy(actor = null, active = false, lastErrorMessage = message, lastError = error) @@ -140,7 +151,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) } receiverInfo(streamId) = newReceiverInfo - ssc.scheduler.listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId))) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -150,14 +161,12 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } /** Add new blocks for the given stream */ - def addBlocks(receivedBlockInfo: ReceivedBlockInfo) { - getReceivedBlockInfoQueue(receivedBlockInfo.streamId) += receivedBlockInfo - logDebug("Stream " + receivedBlockInfo.streamId + " received new blocks: " + - receivedBlockInfo.blockId) + private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { + receivedBlockTracker.addBlock(receivedBlockInfo) } /** Report error sent by a receiver */ - def reportError(streamId: Int, message: String, error: String) { + private def reportError(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => oldInfo.copy(lastErrorMessage = message, lastError = error) @@ -166,7 +175,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) } receiverInfo(streamId) = newReceiverInfo - ssc.scheduler.listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -176,8 +185,8 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } /** Check if any blocks are left to be processed */ - def hasMoreReceivedBlockIds: Boolean = { - !receivedBlockInfo.values.forall(_.isEmpty) + def hasUnallocatedBlocks: Boolean = { + receivedBlockTracker.hasUnallocatedReceivedBlocks } /** Actor to receive messages from the receivers. */ @@ -187,7 +196,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { registerReceiver(streamId, typ, host, receiverActor, sender) sender ! true case AddBlock(receivedBlockInfo) => - addBlocks(receivedBlockInfo) + sender ! addBlock(receivedBlockInfo) case ReportError(streamId, message, error) => reportError(streamId, message, error) case DeregisterReceiver(streamId, message, error) => @@ -202,6 +211,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { @transient val thread = new Thread() { override def run() { try { + SparkEnv.set(env) startReceivers() } catch { case ie: InterruptedException => logInfo("ReceiverLauncher interrupted") @@ -252,6 +262,9 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ssc.sc.makeRDD(receivers, receivers.size) } + val checkpointDirOption = Option(ssc.checkpointDir) + val serializableHadoopConf = new SerializableWritable(ssc.sparkContext.hadoopConfiguration) + // Function to start the receiver on the worker node val startReceiver = (iterator: Iterator[Receiver[_]]) => { if (!iterator.hasNext) { @@ -259,9 +272,10 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { "Could not start receiver as object not found.") } val receiver = iterator.next() - val executor = new ReceiverSupervisorImpl(receiver, SparkEnv.get) - executor.start() - executor.awaitTermination() + val supervisor = new ReceiverSupervisorImpl( + receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) + supervisor.start() + supervisor.awaitTermination() } // Run the dummy Spark job to ensure that all slaves have registered. // This avoids all the receivers to be scheduled on the same node. @@ -271,7 +285,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") - ssc.sparkContext.runJob(tempRDD, startReceiver) + ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) logInfo("All of the receivers have been terminated") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala index 39145a3ab081a..7cd867ce34b87 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala @@ -41,13 +41,7 @@ class SystemClock() extends Clock { return currentTime } - val pollTime = { - if (waitTime / 10.0 > minPollTime) { - (waitTime / 10.0).toLong - } else { - minPollTime - } - } + val pollTime = math.max(waitTime / 10.0, minPollTime).toLong while (true) { currentTime = System.currentTimeMillis() @@ -55,12 +49,7 @@ class SystemClock() extends Clock { if (waitTime <= 0) { return currentTime } - val sleepTime = - if (waitTime < pollTime) { - waitTime - } else { - pollTime - } + val sleepTime = math.min(waitTime, pollTime) Thread.sleep(sleepTime) } -1 diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index 491f1175576e6..27a28bab83ed5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -52,12 +52,14 @@ private[streaming] object HdfsUtils { } } - def getBlockLocations(path: String, conf: Configuration): Option[Array[String]] = { + /** Get the locations of the HDFS blocks containing the given file segment. */ + def getFileSegmentLocations( + path: String, offset: Long, length: Long, conf: Configuration): Array[String] = { val dfsPath = new Path(path) val dfs = getFileSystemForPath(dfsPath, conf) val fileStatus = dfs.getFileStatus(dfsPath) - val blockLocs = Option(dfs.getFileBlockLocations(fileStatus, 0, fileStatus.getLen)) - blockLocs.map(_.flatMap(_.getHosts)) + val blockLocs = Option(dfs.getFileBlockLocations(fileStatus, offset, length)) + blockLocs.map(_.flatMap(_.getHosts)).getOrElse(Array.empty) } def getFileSystemForPath(path: Path, conf: Configuration): FileSystem = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala index 92bad7a882a65..003989092a42a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala @@ -52,4 +52,3 @@ private[streaming] class WriteAheadLogRandomReader(path: String, conf: Configura HdfsUtils.checkState(!closed, "Stream is closed. Create a new Reader to read from the file.") } } - diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 6c8bb50145367..dbab685dc3511 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -17,18 +17,19 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.StreamingContext._ - -import org.apache.spark.rdd.{BlockRDD, RDD} -import org.apache.spark.SparkContext._ +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.language.existentials +import scala.reflect.ClassTag import util.ManualClock -import org.apache.spark.{SparkException, SparkConf} -import org.apache.spark.streaming.dstream.{WindowedDStream, DStream} -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} -import scala.reflect.ClassTag + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel -import scala.collection.mutable +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} class BasicOperationsSuite extends TestSuiteBase { test("map") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala new file mode 100644 index 0000000000000..9efe15d01ed0c --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.File +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.postfixOps + +import akka.actor.{ActorSystem, Props} +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark._ +import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.storage._ +import org.apache.spark.streaming.receiver._ +import org.apache.spark.streaming.util._ +import org.apache.spark.util.AkkaUtils +import WriteAheadLogBasedBlockHandler._ +import WriteAheadLogSuite._ + +class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { + + val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + val hadoopConf = new Configuration() + val storageLevel = StorageLevel.MEMORY_ONLY_SER + val streamId = 1 + val securityMgr = new SecurityManager(conf) + val mapOutputTracker = new MapOutputTrackerMaster(conf) + val shuffleManager = new HashShuffleManager(conf) + val serializer = new KryoSerializer(conf) + val manualClock = new ManualClock + val blockManagerSize = 10000000 + + var actorSystem: ActorSystem = null + var blockManagerMaster: BlockManagerMaster = null + var blockManager: BlockManager = null + var tempDirectory: File = null + + before { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + "test", "localhost", 0, conf = conf, securityManager = securityMgr) + this.actorSystem = actorSystem + conf.set("spark.driver.port", boundPort.toString) + + blockManagerMaster = new BlockManagerMaster( + actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), + conf, true) + + blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, + blockManagerSize, conf, mapOutputTracker, shuffleManager, + new NioBlockTransferService(conf, securityMgr), securityMgr) + blockManager.initialize("app-id") + + tempDirectory = Files.createTempDir() + manualClock.setTime(0) + } + + after { + if (blockManager != null) { + blockManager.stop() + blockManager = null + } + if (blockManagerMaster != null) { + blockManagerMaster.stop() + blockManagerMaster = null + } + actorSystem.shutdown() + actorSystem.awaitTermination() + actorSystem = null + + if (tempDirectory != null && tempDirectory.exists()) { + FileUtils.deleteDirectory(tempDirectory) + tempDirectory = null + } + } + + test("BlockManagerBasedBlockHandler - store blocks") { + withBlockManagerBasedBlockHandler { handler => + testBlockStoring(handler) { case (data, blockIds, storeResults) => + // Verify the data in block manager is correct + val storedData = blockIds.flatMap { blockId => + blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + }.toList + storedData shouldEqual data + + // Verify that the store results are instances of BlockManagerBasedStoreResult + assert( + storeResults.forall { _.isInstanceOf[BlockManagerBasedStoreResult] }, + "Unexpected store result type" + ) + } + } + } + + test("BlockManagerBasedBlockHandler - handle errors in storing block") { + withBlockManagerBasedBlockHandler { handler => + testErrorHandling(handler) + } + } + + test("WriteAheadLogBasedBlockHandler - store blocks") { + withWriteAheadLogBasedBlockHandler { handler => + testBlockStoring(handler) { case (data, blockIds, storeResults) => + // Verify the data in block manager is correct + val storedData = blockIds.flatMap { blockId => + blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + }.toList + storedData shouldEqual data + + // Verify that the store results are instances of WriteAheadLogBasedStoreResult + assert( + storeResults.forall { _.isInstanceOf[WriteAheadLogBasedStoreResult] }, + "Unexpected store result type" + ) + // Verify the data in write ahead log files is correct + val fileSegments = storeResults.map { _.asInstanceOf[WriteAheadLogBasedStoreResult].segment} + val loggedData = fileSegments.flatMap { segment => + val reader = new WriteAheadLogRandomReader(segment.path, hadoopConf) + val bytes = reader.read(segment) + reader.close() + blockManager.dataDeserialize(generateBlockId(), bytes).toList + } + loggedData shouldEqual data + } + } + } + + test("WriteAheadLogBasedBlockHandler - handle errors in storing block") { + withWriteAheadLogBasedBlockHandler { handler => + testErrorHandling(handler) + } + } + + test("WriteAheadLogBasedBlockHandler - cleanup old blocks") { + withWriteAheadLogBasedBlockHandler { handler => + val blocks = Seq.tabulate(10) { i => IteratorBlock(Iterator(1 to i)) } + storeBlocks(handler, blocks) + + val preCleanupLogFiles = getWriteAheadLogFiles() + preCleanupLogFiles.size should be > 1 + + // this depends on the number of blocks inserted using generateAndStoreData() + manualClock.currentTime() shouldEqual 5000L + + val cleanupThreshTime = 3000L + handler.cleanupOldBlock(cleanupThreshTime) + eventually(timeout(10000 millis), interval(10 millis)) { + getWriteAheadLogFiles().size should be < preCleanupLogFiles.size + } + } + } + + /** + * Test storing of data using different forms of ReceivedBlocks and verify that they succeeded + * using the given verification function + */ + private def testBlockStoring(receivedBlockHandler: ReceivedBlockHandler) + (verifyFunc: (Seq[String], Seq[StreamBlockId], Seq[ReceivedBlockStoreResult]) => Unit) { + val data = Seq.tabulate(100) { _.toString } + + def storeAndVerify(blocks: Seq[ReceivedBlock]) { + blocks should not be empty + val (blockIds, storeResults) = storeBlocks(receivedBlockHandler, blocks) + withClue(s"Testing with ${blocks.head.getClass.getSimpleName}s:") { + // Verify returns store results have correct block ids + (storeResults.map { _.blockId }) shouldEqual blockIds + + // Call handler-specific verification function + verifyFunc(data, blockIds, storeResults) + } + } + + def dataToByteBuffer(b: Seq[String]) = blockManager.dataSerialize(generateBlockId, b.iterator) + + val blocks = data.grouped(10).toSeq + + storeAndVerify(blocks.map { b => IteratorBlock(b.toIterator) }) + storeAndVerify(blocks.map { b => ArrayBufferBlock(new ArrayBuffer ++= b) }) + storeAndVerify(blocks.map { b => ByteBufferBlock(dataToByteBuffer(b)) }) + } + + /** Test error handling when blocks that cannot be stored */ + private def testErrorHandling(receivedBlockHandler: ReceivedBlockHandler) { + // Handle error in iterator (e.g. divide-by-zero error) + intercept[Exception] { + val iterator = (10 to (-10, -1)).toIterator.map { _ / 0 } + receivedBlockHandler.storeBlock(StreamBlockId(1, 1), IteratorBlock(iterator)) + } + + // Handler error in block manager storing (e.g. too big block) + intercept[SparkException] { + val byteBuffer = ByteBuffer.wrap(new Array[Byte](blockManagerSize + 1)) + receivedBlockHandler.storeBlock(StreamBlockId(1, 1), ByteBufferBlock(byteBuffer)) + } + } + + /** Instantiate a BlockManagerBasedBlockHandler and run a code with it */ + private def withBlockManagerBasedBlockHandler(body: BlockManagerBasedBlockHandler => Unit) { + body(new BlockManagerBasedBlockHandler(blockManager, storageLevel)) + } + + /** Instantiate a WriteAheadLogBasedBlockHandler and run a code with it */ + private def withWriteAheadLogBasedBlockHandler(body: WriteAheadLogBasedBlockHandler => Unit) { + val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, 1, + storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock) + try { + body(receivedBlockHandler) + } finally { + receivedBlockHandler.stop() + } + } + + /** Store blocks using a handler */ + private def storeBlocks( + receivedBlockHandler: ReceivedBlockHandler, + blocks: Seq[ReceivedBlock] + ): (Seq[StreamBlockId], Seq[ReceivedBlockStoreResult]) = { + val blockIds = Seq.fill(blocks.size)(generateBlockId()) + val storeResults = blocks.zip(blockIds).map { + case (block, id) => + manualClock.addToTime(500) // log rolling interval set to 1000 ms through SparkConf + logDebug("Inserting block " + id) + receivedBlockHandler.storeBlock(id, block) + }.toList + logDebug("Done inserting") + (blockIds, storeResults) + } + + private def getWriteAheadLogFiles(): Seq[String] = { + getLogFilesInDirectory(checkpointDirToLogDir(tempDirectory.toString, streamId)) + } + + private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala new file mode 100644 index 0000000000000..fd9c97f551c62 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.File + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.{implicitConversions, postfixOps} +import scala.util.Random + +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.util.{Clock, ManualClock, SystemClock, WriteAheadLogReader} +import org.apache.spark.streaming.util.WriteAheadLogSuite._ +import org.apache.spark.util.Utils + +class ReceivedBlockTrackerSuite + extends FunSuite with BeforeAndAfter with Matchers with Logging { + + val conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite") + conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1") + + val hadoopConf = new Configuration() + val akkaTimeout = 10 seconds + val streamId = 1 + + var allReceivedBlockTrackers = new ArrayBuffer[ReceivedBlockTracker]() + var checkpointDirectory: File = null + + before { + checkpointDirectory = Files.createTempDir() + } + + after { + allReceivedBlockTrackers.foreach { _.stop() } + if (checkpointDirectory != null && checkpointDirectory.exists()) { + FileUtils.deleteDirectory(checkpointDirectory) + checkpointDirectory = null + } + } + + test("block addition, and block to batch allocation") { + val receivedBlockTracker = createTracker(enableCheckpoint = false) + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + + val blockInfos = generateBlockInfos() + blockInfos.map(receivedBlockTracker.addBlock) + + // Verify added blocks are unallocated blocks + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos + + // Allocate the blocks to a batch and verify that all of them have been allocated + receivedBlockTracker.allocateBlocksToBatch(1) + receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldBe empty + + // Allocate no blocks to another batch + receivedBlockTracker.allocateBlocksToBatch(2) + receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty + + // Verify that batch 2 cannot be allocated again + intercept[SparkException] { + receivedBlockTracker.allocateBlocksToBatch(2) + } + + // Verify that older batches cannot be allocated again + intercept[SparkException] { + receivedBlockTracker.allocateBlocksToBatch(1) + } + } + + test("block addition, block to batch allocation and cleanup with write ahead log") { + val manualClock = new ManualClock + conf.getInt( + "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", -1) should be (1) + + // Set the time increment level to twice the rotation interval so that every increment creates + // a new log file + val timeIncrementMillis = 2000L + def incrementTime() { + manualClock.addToTime(timeIncrementMillis) + } + + // Generate and add blocks to the given tracker + def addBlockInfos(tracker: ReceivedBlockTracker): Seq[ReceivedBlockInfo] = { + val blockInfos = generateBlockInfos() + blockInfos.map(tracker.addBlock) + blockInfos + } + + // Print the data present in the log ahead files in the log directory + def printLogFiles(message: String) { + val fileContents = getWriteAheadLogFiles().map { file => + (s"\n>>>>> $file: <<<<<\n${getWrittenLogData(file).mkString("\n")}") + }.mkString("\n") + logInfo(s"\n\n=====================\n$message\n$fileContents\n=====================\n") + } + + // Start tracker and add blocks + val tracker1 = createTracker(enableCheckpoint = true, clock = manualClock) + val blockInfos1 = addBlockInfos(tracker1) + tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 + + // Verify whether write ahead log has correct contents + val expectedWrittenData1 = blockInfos1.map(BlockAdditionEvent) + getWrittenLogData() shouldEqual expectedWrittenData1 + getWriteAheadLogFiles() should have size 1 + + // Restart tracker and verify recovered list of unallocated blocks + incrementTime() + val tracker2 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 + + // Allocate blocks to batch and verify whether the unallocated blocks got allocated + val batchTime1 = manualClock.currentTime + tracker2.allocateBlocksToBatch(batchTime1) + tracker2.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 + + // Add more blocks and allocate to another batch + incrementTime() + val batchTime2 = manualClock.currentTime + val blockInfos2 = addBlockInfos(tracker2) + tracker2.allocateBlocksToBatch(batchTime2) + tracker2.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + + // Verify whether log has correct contents + val expectedWrittenData2 = expectedWrittenData1 ++ + Seq(createBatchAllocation(batchTime1, blockInfos1)) ++ + blockInfos2.map(BlockAdditionEvent) ++ + Seq(createBatchAllocation(batchTime2, blockInfos2)) + getWrittenLogData() shouldEqual expectedWrittenData2 + + // Restart tracker and verify recovered state + incrementTime() + val tracker3 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 + tracker3.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + tracker3.getUnallocatedBlocks(streamId) shouldBe empty + + // Cleanup first batch but not second batch + val oldestLogFile = getWriteAheadLogFiles().head + incrementTime() + tracker3.cleanupOldBatches(batchTime2) + + // Verify that the batch allocations have been cleaned, and the act has been written to log + tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual Seq.empty + getWrittenLogData(getWriteAheadLogFiles().last) should contain(createBatchCleanup(batchTime1)) + + // Verify that at least one log file gets deleted + eventually(timeout(10 seconds), interval(10 millisecond)) { + getWriteAheadLogFiles() should not contain oldestLogFile + } + printLogFiles("After cleanup") + + // Restart tracker and verify recovered state, specifically whether info about the first + // batch has been removed, but not the second batch + incrementTime() + val tracker4 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker4.getUnallocatedBlocks(streamId) shouldBe empty + tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty // should be cleaned + tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + } + + /** + * Create tracker object with the optional provided clock. Use fake clock if you + * want to control time by manually incrementing it to test log cleanup. + */ + def createTracker(enableCheckpoint: Boolean, clock: Clock = new SystemClock): ReceivedBlockTracker = { + val cpDirOption = if (enableCheckpoint) Some(checkpointDirectory.toString) else None + val tracker = new ReceivedBlockTracker(conf, hadoopConf, Seq(streamId), clock, cpDirOption) + allReceivedBlockTrackers += tracker + tracker + } + + /** Generate blocks infos using random ids */ + def generateBlockInfos(): Seq[ReceivedBlockInfo] = { + List.fill(5)(ReceivedBlockInfo(streamId, 0, + BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt))))) + } + + /** Get all the data written in the given write ahead log file. */ + def getWrittenLogData(logFile: String): Seq[ReceivedBlockTrackerLogEvent] = { + getWrittenLogData(Seq(logFile)) + } + + /** + * Get all the data written in the given write ahead log files. By default, it will read all + * files in the test log directory. + */ + def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = { + logFiles.flatMap { + file => new WriteAheadLogReader(file, hadoopConf).toSeq + }.map { byteBuffer => + Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) + }.toList + } + + /** Get all the write ahead log files in the test directory */ + def getWriteAheadLogFiles(): Seq[String] = { + import ReceivedBlockTracker._ + val logDir = checkpointDirToLogDir(checkpointDirectory.toString) + getLogFilesInDirectory(logDir).map { _.toString } + } + + /** Create batch allocation object from the given info */ + def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = { + BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) + } + + /** Create batch cleanup object from the given info */ + def createBatchCleanup(time: Long, moreTimes: Long*): BatchCleanupEvent = { + BatchCleanupEvent((Seq(time) ++ moreTimes).map(Time.apply)) + } + + implicit def millisToTime(milliseconds: Long): Time = Time(milliseconds) + + implicit def timeToMillis(time: Time): Long = time.milliseconds +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 655cec1573f58..4b49c4d251645 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -133,11 +133,16 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc.stop() } - test("stop before start and start after stop") { + test("stop before start") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() ssc.stop() // stop before start should not throw exception - ssc.start() + } + + test("start after stop") { + // Regression test for SPARK-4301 + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register() ssc.stop() intercept[SparkException] { ssc.start() // start after stop should throw exception @@ -157,6 +162,18 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc.stop() } + test("stop(stopSparkContext=true) after stop(stopSparkContext=false)") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register() + ssc.stop(stopSparkContext = false) + assert(ssc.sc.makeRDD(1 to 100).collect().size === 100) + ssc.stop(stopSparkContext = true) + // Check that the SparkContext is actually stopped: + intercept[Exception] { + ssc.sc.makeRDD(1 to 100).collect() + } + } + test("stop gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) conf.set("spark.cleaner.ttl", "3600") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala new file mode 100644 index 0000000000000..d2b983c4b4d1a --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.rdd + +import java.io.File + +import scala.util.Random + +import com.google.common.io.Files +import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} +import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogWriter} + +class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll { + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName(this.getClass.getSimpleName) + val hadoopConf = new Configuration() + + var sparkContext: SparkContext = null + var blockManager: BlockManager = null + var dir: File = null + + override def beforeAll(): Unit = { + sparkContext = new SparkContext(conf) + blockManager = sparkContext.env.blockManager + dir = Files.createTempDir() + } + + override def afterAll(): Unit = { + // Copied from LocalSparkContext, simpler than to introduced test dependencies to core tests. + sparkContext.stop() + dir.delete() + System.clearProperty("spark.driver.port") + } + + test("Read data available in block manager and write ahead log") { + testRDD(5, 5) + } + + test("Read data available only in block manager, not in write ahead log") { + testRDD(5, 0) + } + + test("Read data available only in write ahead log, not in block manager") { + testRDD(0, 5) + } + + test("Read data available only in write ahead log, and test storing in block manager") { + testRDD(0, 5, testStoreInBM = true) + } + + test("Read data with partially available in block manager, and rest in write ahead log") { + testRDD(3, 2) + } + + /** + * Test the WriteAheadLogBackedRDD, by writing some partitions of the data to block manager + * and the rest to a write ahead log, and then reading reading it all back using the RDD. + * It can also test if the partitions that were read from the log were again stored in + * block manager. + * @param numPartitionsInBM Number of partitions to write to the Block Manager + * @param numPartitionsInWAL Number of partitions to write to the Write Ahead Log + * @param testStoreInBM Test whether blocks read from log are stored back into block manager + */ + private def testRDD(numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { + val numBlocks = numPartitionsInBM + numPartitionsInWAL + val data = Seq.fill(numBlocks, 10)(scala.util.Random.nextString(50)) + + // Put the necessary blocks in the block manager + val blockIds = Array.fill(numBlocks)(StreamBlockId(Random.nextInt(), Random.nextInt())) + data.zip(blockIds).take(numPartitionsInBM).foreach { case(block, blockId) => + blockManager.putIterator(blockId, block.iterator, StorageLevel.MEMORY_ONLY_SER) + } + + // Generate write ahead log segments + val segments = generateFakeSegments(numPartitionsInBM) ++ + writeLogSegments(data.takeRight(numPartitionsInWAL), blockIds.takeRight(numPartitionsInWAL)) + + // Make sure that the left `numPartitionsInBM` blocks are in block manager, and others are not + require( + blockIds.take(numPartitionsInBM).forall(blockManager.get(_).nonEmpty), + "Expected blocks not in BlockManager" + ) + require( + blockIds.takeRight(numPartitionsInWAL).forall(blockManager.get(_).isEmpty), + "Unexpected blocks in BlockManager" + ) + + // Make sure that the right `numPartitionsInWAL` blocks are in write ahead logs, and other are not + require( + segments.takeRight(numPartitionsInWAL).forall(s => + new File(s.path.stripPrefix("file://")).exists()), + "Expected blocks not in write ahead log" + ) + require( + segments.take(numPartitionsInBM).forall(s => + !new File(s.path.stripPrefix("file://")).exists()), + "Unexpected blocks in write ahead log" + ) + + // Create the RDD and verify whether the returned data is correct + val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, + segments.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY) + assert(rdd.collect() === data.flatten) + + if (testStoreInBM) { + val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, + segments.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY) + assert(rdd2.collect() === data.flatten) + assert( + blockIds.forall(blockManager.get(_).nonEmpty), + "All blocks not found in block manager" + ) + } + } + + private def writeLogSegments( + blockData: Seq[Seq[String]], + blockIds: Seq[BlockId] + ): Seq[WriteAheadLogFileSegment] = { + require(blockData.size === blockIds.size) + val writer = new WriteAheadLogWriter(new File(dir, Random.nextString(10)).toString, hadoopConf) + val segments = blockData.zip(blockIds).map { case (data, id) => + writer.write(blockManager.dataSerialize(id, data.iterator)) + } + writer.close() + segments + } + + private def generateFakeSegments(count: Int): Seq[WriteAheadLogFileSegment] = { + Array.fill(count)(new WriteAheadLogFileSegment("random", 0l, 0)) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 5eba93c208c50..1956a4f1db90a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -58,7 +58,7 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { test("WriteAheadLogWriter - writing data") { val dataToWrite = generateRandomData() val segments = writeDataUsingWriter(testFile, dataToWrite) - val writtenData = readDataManually(testFile, segments) + val writtenData = readDataManually(segments) assert(writtenData === dataToWrite) } @@ -67,7 +67,7 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { val writer = new WriteAheadLogWriter(testFile, hadoopConf) dataToWrite.foreach { data => val segment = writer.write(stringToByteBuffer(data)) - val dataRead = readDataManually(testFile, Seq(segment)).head + val dataRead = readDataManually(Seq(segment)).head assert(data === dataRead) } writer.close() @@ -281,14 +281,20 @@ object WriteAheadLogSuite { } /** Read data from a segments of a log file directly and return the list of byte buffers.*/ - def readDataManually(file: String, segments: Seq[WriteAheadLogFileSegment]): Seq[String] = { - val reader = HdfsUtils.getInputStream(file, hadoopConf) - segments.map { x => - reader.seek(x.offset) - val data = new Array[Byte](x.length) - reader.readInt() - reader.readFully(data) - Utils.deserialize[String](data) + def readDataManually(segments: Seq[WriteAheadLogFileSegment]): Seq[String] = { + segments.map { segment => + val reader = HdfsUtils.getInputStream(segment.path, hadoopConf) + try { + reader.seek(segment.offset) + val bytes = new Array[Byte](segment.length) + reader.readInt() + reader.readFully(bytes) + val data = Utils.deserialize[String](bytes) + reader.close() + data + } finally { + reader.close() + } } } @@ -335,9 +341,11 @@ object WriteAheadLogSuite { val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { - fileSystem.listStatus(logDirectoryPath).map { - _.getPath.toString.stripPrefix("file:") - }.sorted + fileSystem.listStatus(logDirectoryPath).map { _.getPath() }.sortBy { + _.getName().split("-")(1).toLong + }.map { + _.toString.stripPrefix("file:") + } } else { Seq.empty } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 7ee4b5c842df1..7023a1170654f 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils} import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.network.util.JavaUtils @deprecated("use yarn/stable", "1.2.0") class ExecutorRunnable( @@ -90,6 +91,22 @@ class ExecutorRunnable( ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // If external shuffle service is enabled, register with the Yarn shuffle service already + // started on the NodeManager and, if authentication is enabled, provide it with our secret + // key for fetching shuffle files later + if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) { + val secretString = securityMgr.getSecretKey() + val secretBytes = + if (secretString != null) { + // This conversion must match how the YarnShuffleService decodes our secret + JavaUtils.stringToBytes(secretString) + } else { + // Authentication is not enabled, so just provide dummy metadata + ByteBuffer.allocate(0) + } + ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + } + // Send the start request to the ContainerManager val startReq = Records.newRecord(classOf[StartContainerRequest]) .asInstanceOf[StartContainerRequest] diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 5c54e3400301a..8b32c76d14037 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.yarn import org.apache.spark.util.{MemoryParam, IntParam} +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import collection.mutable.ArrayBuffer class ApplicationMasterArguments(val args: Array[String]) { @@ -26,7 +27,7 @@ class ApplicationMasterArguments(val args: Array[String]) { var userArgs: Seq[String] = Seq[String]() var executorMemory = 1024 var executorCores = 1 - var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS + var numExecutors = DEFAULT_NUMBER_EXECUTORS parseArgs(args.toList) @@ -81,7 +82,7 @@ class ApplicationMasterArguments(val args: Array[String]) { | --jar JAR_PATH Path to your application's JAR file | --class CLASS_NAME Name of your application's main class | --args ARGS Arguments to be passed to your application's main class. - | Mutliple invocations are possible, each will be passed in order. + | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores for the executors (Default: 1) | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index a12f82d2fbe70..4d859450efc63 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.yarn import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkConf -import org.apache.spark.util.{Utils, IntParam, MemoryParam} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ +import org.apache.spark.util.{Utils, IntParam, MemoryParam} // TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) { @@ -33,23 +33,25 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var userArgs: Seq[String] = Seq[String]() var executorMemory = 1024 // MB var executorCores = 1 - var numExecutors = 2 + var numExecutors = DEFAULT_NUMBER_EXECUTORS var amQueue = sparkConf.get("spark.yarn.queue", "default") var amMemory: Int = 512 // MB var appName: String = "Spark" var priority = 0 - parseArgs(args.toList) - loadEnvironmentArgs() - // Additional memory to allocate to containers // For now, use driver's memory overhead as our AM container's memory overhead - val amMemoryOverhead = sparkConf.getInt("spark.yarn.driver.memoryOverhead", + val amMemoryOverhead = sparkConf.getInt("spark.yarn.driver.memoryOverhead", math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toInt, MEMORY_OVERHEAD_MIN)) - val executorMemoryOverhead = sparkConf.getInt("spark.yarn.executor.memoryOverhead", + val executorMemoryOverhead = sparkConf.getInt("spark.yarn.executor.memoryOverhead", math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)) + private val isDynamicAllocationEnabled = + sparkConf.getBoolean("spark.dynamicAllocation.enabled", false) + + parseArgs(args.toList) + loadEnvironmentArgs() validateArgs() /** Load any default arguments provided through environment variables and Spark properties. */ @@ -64,6 +66,15 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) .orElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES")) .orElse(sparkConf.getOption("spark.yarn.dist.archives").map(p => Utils.resolveURIs(p))) .orNull + // If dynamic allocation is enabled, start at the max number of executors + if (isDynamicAllocationEnabled) { + val maxExecutorsConf = "spark.dynamicAllocation.maxExecutors" + if (!sparkConf.contains(maxExecutorsConf)) { + throw new IllegalArgumentException( + s"$maxExecutorsConf must be set if dynamic allocation is enabled!") + } + numExecutors = sparkConf.get(maxExecutorsConf).toInt + } } /** @@ -113,6 +124,11 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) if (args(0) == "--num-workers") { println("--num-workers is deprecated. Use --num-executors instead.") } + // Dynamic allocation is not compatible with this option + if (isDynamicAllocationEnabled) { + throw new IllegalArgumentException("Explicitly setting the number " + + "of executors is not compatible with spark.dynamicAllocation.enabled!") + } numExecutors = value args = tail diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 7ae8ef237ff89..b32e15738f28b 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy.yarn import java.util.{List => JList} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicInteger +import java.util.regex.Pattern import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -63,6 +64,8 @@ private[yarn] abstract class YarnAllocator( securityMgr: SecurityManager) extends Logging { + import YarnAllocator._ + // These three are locked on allocatedHostToContainersMap. Complementary data structures // allocatedHostToContainersMap : containers which are running : host, Set // allocatedContainerToHostMap: container to host mapping. @@ -375,12 +378,22 @@ private[yarn] abstract class YarnAllocator( logInfo("Completed container %s (state: %s, exit status: %s)".format( containerId, completedContainer.getState, - completedContainer.getExitStatus())) + completedContainer.getExitStatus)) // Hadoop 2.2.X added a ContainerExitStatus we should switch to use // there are some exit status' we shouldn't necessarily count against us, but for // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus() != 0) { - logInfo("Container marked as failed: " + containerId) + if (completedContainer.getExitStatus == -103) { // vmem limit exceeded + logWarning(memLimitExceededLogMessage( + completedContainer.getDiagnostics, + VMEM_EXCEEDED_PATTERN)) + } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded + logWarning(memLimitExceededLogMessage( + completedContainer.getDiagnostics, + PMEM_EXCEEDED_PATTERN)) + } else if (completedContainer.getExitStatus != 0) { + logInfo("Container marked as failed: " + containerId + + ". Exit status: " + completedContainer.getExitStatus + + ". Diagnostics: " + completedContainer.getDiagnostics) numExecutorsFailed.incrementAndGet() } } @@ -508,3 +521,18 @@ private[yarn] abstract class YarnAllocator( } } + +private object YarnAllocator { + val MEM_REGEX = "[0-9.]+ [KMG]B" + val PMEM_EXCEEDED_PATTERN = + Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used") + val VMEM_EXCEEDED_PATTERN = + Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used") + + def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { + val matcher = pattern.matcher(diagnostics) + val diag = if (matcher.find()) " " + matcher.group() + "." else "" + ("Container killed by YARN for exceeding memory limits." + diag + + " Consider boosting spark.yarn.executor.memoryOverhead.") + } +} diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index e1e0144f46fe9..7d453ecb7983c 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -93,6 +93,8 @@ object YarnSparkHadoopUtil { val ANY_HOST = "*" + val DEFAULT_NUMBER_EXECUTORS = 2 + // All RM requests are issued with same priority : we do not (yet) have any distinction between // request types (like map/reduce in hadoop for example) val RM_REQUEST_PRIORITY = 1 diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index f6f6dc52433e5..2923e6729cd6b 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -33,7 +33,7 @@ private[spark] class YarnClientSchedulerBackend( private var client: Client = null private var appId: ApplicationId = null - private var stopping: Boolean = false + @volatile private var stopping: Boolean = false /** * Create a Yarn client to submit an application to the ResourceManager. diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index a96a54f66824c..b1de81e6a8b0f 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler.cluster import org.apache.spark.SparkContext -import org.apache.spark.deploy.yarn.ApplicationMasterArguments +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.IntParam @@ -29,7 +29,7 @@ private[spark] class YarnClusterSchedulerBackend( override def start() { super.start() - totalExpectedExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS + totalExpectedExecutors = DEFAULT_NUMBER_EXECUTORS if (System.getenv("SPARK_EXECUTOR_INSTANCES") != null) { totalExpectedExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")) .getOrElse(totalExpectedExecutors) diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala new file mode 100644 index 0000000000000..8d184a09d64cc --- /dev/null +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.spark.deploy.yarn.YarnAllocator._ +import org.scalatest.FunSuite + +class YarnAllocatorSuite extends FunSuite { + test("memory exceeded diagnostic regexes") { + val diagnostics = + "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + + "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " + + "5.8 GB of 4.2 GB virtual memory used. Killing container." + val vmemMsg = memLimitExceededLogMessage(diagnostics, VMEM_EXCEEDED_PATTERN) + val pmemMsg = memLimitExceededLogMessage(diagnostics, PMEM_EXCEEDED_PATTERN) + assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used.")) + assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used.")) + } +} diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 0b5a92d87d722..fdd3c2300fa78 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.network.util.JavaUtils class ExecutorRunnable( @@ -89,6 +90,22 @@ class ExecutorRunnable( ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // If external shuffle service is enabled, register with the Yarn shuffle service already + // started on the NodeManager and, if authentication is enabled, provide it with our secret + // key for fetching shuffle files later + if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) { + val secretString = securityMgr.getSecretKey() + val secretBytes = + if (secretString != null) { + // This conversion must match how the YarnShuffleService decodes our secret + JavaUtils.stringToBytes(secretString) + } else { + // Authentication is not enabled, so just provide dummy metadata + ByteBuffer.allocate(0) + } + ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + } + // Send the start request to the ContainerManager nmClient.startContainer(container, ctx) }