From 38112905bc3b33f2ae75274afba1c30e116f6e46 Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin
Date: Wed, 10 Jun 2015 13:17:29 -0700
Subject: [PATCH 01/18] [SPARK-5479] [YARN] Handle --py-files correctly in YARN.

The bug description is a little misleading: the actual issue is that .py files are not handled correctly when distributed by YARN. They're added to "spark.submit.pyFiles", which, when processed by context.py, explicitly whitelists certain extensions (see PACKAGE_EXTENSIONS), and that does not include .py files.

On top of that, archives were not handled at all! They made it to the driver's python path, but never made it to executors, since the mechanism used to propagate their location (spark.submit.pyFiles) only works on the driver side.

So, instead, ignore "spark.submit.pyFiles" and just build PYTHONPATH correctly for both driver and executors. Individual .py files are placed in a subdirectory of the container's local dir in the cluster, which is then added to the python path. Archives are added directly.

The change, as a side effect, ends up solving the symptom described in the bug. The issue was not that the files were not being distributed, but that they were never made visible to the python application running under Spark.

Also included is a proper unit test for running python on YARN, which broke in several different ways with the previous code.

A short walk-through of the changes:
- SparkSubmit does not try to be smart about how YARN handles python files anymore. It just passes down the configs to the YARN client code.
- The YARN client distributes python files and archives differently, placing the files in a subdirectory.
- The YARN client now sets PYTHONPATH for the processes it launches; to properly handle different locations, it uses YARN's support for embedding env variables, so to avoid YARN expanding those at the wrong time, SparkConf is now propagated to the AM using a conf file instead of command line options. (A simplified sketch of the resulting PYTHONPATH layout follows the commit list below.)
- Because the Client initialization code is a maze of implicit dependencies, some code needed to be moved around to make sure all needed state was available when the code ran.
- The pyspark tests in YarnClusterSuite now actually distribute and try to use both a python file and an archive containing a different python module. Also added a yarn-client test for completeness.
- I cleaned up some of the code around distributing files to YARN, to avoid adding more copied & pasted code to handle the new files being distributed.

Author: Marcelo Vanzin

Closes #6360 from vanzin/SPARK-5479 and squashes the following commits:

bcaf7e6 [Marcelo Vanzin] Feedback.
c47501f [Marcelo Vanzin] Fix yarn-client mode.
46b1d0c [Marcelo Vanzin] Merge branch 'master' into SPARK-5479
c743778 [Marcelo Vanzin] Only pyspark cares about python archives.
c8e5a82 [Marcelo Vanzin] Actually run pyspark in client mode.
705571d [Marcelo Vanzin] Move some code to the YARN module.
1dd4d0c [Marcelo Vanzin] Review feedback.
71ee736 [Marcelo Vanzin] Merge branch 'master' into SPARK-5479
220358b [Marcelo Vanzin] Scalastyle.
cdbb990 [Marcelo Vanzin] Merge branch 'master' into SPARK-5479
7fe3cd4 [Marcelo Vanzin] No need to distribute primary file to executors.
09045f1 [Marcelo Vanzin] Style.
943cbf4 [Marcelo Vanzin] [SPARK-5479] [yarn] Handle --py-files correctly in YARN.
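As an illustrative aside (not part of the patch), the PYTHONPATH layout described above can be sketched in a few lines of standalone Scala. The `__pyfiles__` name matches the localized subdirectory introduced by this change; `buildPythonPath` and its arguments are made up for illustration only:

```scala
import java.io.File

// Standalone sketch of the approach described above, not the actual Client code:
// plain .py files all resolve to one localized subdirectory of the container's
// working dir, while archives are referenced by their own localized names.
object PyPathSketch {
  def buildPythonPath(pyFiles: Seq[String], containerPwd: String): String = {
    val (plainPy, archives) = pyFiles.partition(_.endsWith(".py"))
    val entries = scala.collection.mutable.ListBuffer[String]()
    if (plainPy.nonEmpty) {
      // All individual .py files live under a single "__pyfiles__" subdirectory.
      entries += Seq(containerPwd, "__pyfiles__").mkString(File.separator)
    }
    // Archives (.zip, .egg) are added to PYTHONPATH directly, one entry each.
    archives.foreach { a =>
      entries += Seq(containerPwd, new File(a).getName).mkString(File.separator)
    }
    entries.mkString(File.pathSeparator)
  }

  def main(args: Array[String]): Unit = {
    // e.g. "{{PWD}}/__pyfiles__:{{PWD}}/deps.zip" on a Unix-like container
    println(buildPythonPath(Seq("util.py", "deps.zip"), "{{PWD}}"))
  }
}
```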
--- .../org/apache/spark/deploy/SparkSubmit.scala | 77 +---- .../spark/deploy/yarn/ApplicationMaster.scala | 20 +- .../yarn/ApplicationMasterArguments.scala | 12 +- .../org/apache/spark/deploy/yarn/Client.scala | 295 +++++++++++------- .../spark/deploy/yarn/ClientArguments.scala | 4 +- .../cluster/YarnClientSchedulerBackend.scala | 5 +- .../spark/deploy/yarn/ClientSuite.scala | 4 +- .../spark/deploy/yarn/YarnClusterSuite.scala | 61 ++-- 8 files changed, 270 insertions(+), 208 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index a0eae774268ed..b8978e25a02d2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -324,55 +324,20 @@ object SparkSubmit { // Usage: PythonAppRunner
[app arguments] args.mainClass = "org.apache.spark.deploy.PythonRunner" args.childArgs = ArrayBuffer(args.primaryResource, args.pyFiles) ++ args.childArgs - args.files = mergeFileLists(args.files, args.primaryResource) + if (clusterManager != YARN) { + // The YARN backend distributes the primary file differently, so don't merge it. + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + if (clusterManager != YARN) { + // The YARN backend handles python files differently, so don't merge the lists. + args.files = mergeFileLists(args.files, args.pyFiles) } - args.files = mergeFileLists(args.files, args.pyFiles) if (args.pyFiles != null) { sysProps("spark.submit.pyFiles") = args.pyFiles } } - // In yarn mode for a python app, add pyspark archives to files - // that can be distributed with the job - if (args.isPython && clusterManager == YARN) { - var pyArchives: String = null - val pyArchivesEnvOpt = sys.env.get("PYSPARK_ARCHIVES_PATH") - if (pyArchivesEnvOpt.isDefined) { - pyArchives = pyArchivesEnvOpt.get - } else { - if (!sys.env.contains("SPARK_HOME")) { - printErrorAndExit("SPARK_HOME does not exist for python application in yarn mode.") - } - val pythonPath = new ArrayBuffer[String] - for (sparkHome <- sys.env.get("SPARK_HOME")) { - val pyLibPath = Seq(sparkHome, "python", "lib").mkString(File.separator) - val pyArchivesFile = new File(pyLibPath, "pyspark.zip") - if (!pyArchivesFile.exists()) { - printErrorAndExit("pyspark.zip does not exist for python application in yarn mode.") - } - val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip") - if (!py4jFile.exists()) { - printErrorAndExit("py4j-0.8.2.1-src.zip does not exist for python application " + - "in yarn mode.") - } - pythonPath += pyArchivesFile.getAbsolutePath() - pythonPath += py4jFile.getAbsolutePath() - } - pyArchives = pythonPath.mkString(",") - } - - pyArchives = pyArchives.split(",").map { localPath => - val localURI = Utils.resolveURI(localPath) - if (localURI.getScheme != "local") { - args.files = mergeFileLists(args.files, localURI.toString) - new Path(localPath).getName - } else { - localURI.getPath - } - }.mkString(File.pathSeparator) - sysProps("spark.submit.pyArchives") = pyArchives - } - // If we're running a R app, set the main class to our specific R runner if (args.isR && deployMode == CLIENT) { if (args.primaryResource == SPARKR_SHELL) { @@ -386,19 +351,10 @@ object SparkSubmit { } } - if (isYarnCluster) { - // In yarn-cluster mode for a python app, add primary resource and pyFiles to files - // that can be distributed with the job - if (args.isPython) { - args.files = mergeFileLists(args.files, args.primaryResource) - args.files = mergeFileLists(args.files, args.pyFiles) - } - + if (isYarnCluster && args.isR) { // In yarn-cluster mode for a R app, add primary resource to files // that can be distributed with the job - if (args.isR) { - args.files = mergeFileLists(args.files, args.primaryResource) - } + args.files = mergeFileLists(args.files, args.primaryResource) } // Special flag to avoid deprecation warnings at the client @@ -515,17 +471,18 @@ object SparkSubmit { } } + // Let YARN know it's a pyspark app, so it distributes needed libraries. 
+ if (clusterManager == YARN && args.isPython) { + sysProps.put("spark.yarn.isPython", "true") + } + // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (isYarnCluster) { childMainClass = "org.apache.spark.deploy.yarn.Client" if (args.isPython) { - val mainPyFile = new Path(args.primaryResource).getName - childArgs += ("--primary-py-file", mainPyFile) + childArgs += ("--primary-py-file", args.primaryResource) if (args.pyFiles != null) { - // These files will be distributed to each machine's working directory, so strip the - // path prefix - val pyFilesNames = args.pyFiles.split(",").map(p => (new Path(p)).getName).mkString(",") - childArgs += ("--py-files", pyFilesNames) + childArgs += ("--py-files", args.pyFiles) } childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") } else if (args.isR) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 002d7b6eaf498..83dafa4a125d2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} import org.apache.spark.SparkException -import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -46,6 +46,14 @@ private[spark] class ApplicationMaster( client: YarnRMClient) extends Logging { + // Load the properties file with the Spark configuration and set entries as system properties, + // so that user code run inside the AM also has access to them. + if (args.propertiesFile != null) { + Utils.getPropertiesFromFile(args.propertiesFile).foreach { case (k, v) => + sys.props(k) = v + } + } + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. @@ -490,9 +498,11 @@ private[spark] class ApplicationMaster( new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } + var userArgs = args.userArgs if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { - System.setProperty("spark.submit.pyFiles", - PythonRunner.formatPaths(args.pyFiles).mkString(",")) + // When running pyspark, the app is run using PythonRunner. The second argument is the list + // of files to add to PYTHONPATH, which Client.scala already handles, so it's empty. 
+ userArgs = Seq(args.primaryPyFile, "") ++ userArgs } if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { // TODO(davies): add R dependencies here @@ -503,9 +513,7 @@ private[spark] class ApplicationMaster( val userThread = new Thread { override def run() { try { - val mainArgs = new Array[String](args.userArgs.size) - args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) - mainMethod.invoke(null, mainArgs) + mainMethod.invoke(null, userArgs.toArray) finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) logDebug("Done running users class") } catch { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index ae6dc1094d724..68e9f6b4db7f4 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -26,11 +26,11 @@ class ApplicationMasterArguments(val args: Array[String]) { var userClass: String = null var primaryPyFile: String = null var primaryRFile: String = null - var pyFiles: String = null - var userArgs: Seq[String] = Seq[String]() + var userArgs: Seq[String] = Nil var executorMemory = 1024 var executorCores = 1 var numExecutors = DEFAULT_NUMBER_EXECUTORS + var propertiesFile: String = null parseArgs(args.toList) @@ -59,10 +59,6 @@ class ApplicationMasterArguments(val args: Array[String]) { primaryRFile = value args = tail - case ("--py-files") :: value :: tail => - pyFiles = value - args = tail - case ("--args" | "--arg") :: value :: tail => userArgsBuffer += value args = tail @@ -79,6 +75,10 @@ class ApplicationMasterArguments(val args: Array[String]) { executorCores = value args = tail + case ("--properties-file") :: value :: tail => + propertiesFile = value + args = tail + case _ => printUsageAndExit(1, args) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index f4d43214b08ca..ec9402afff329 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,11 +17,12 @@ package org.apache.spark.deploy.yarn -import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException} +import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException, + OutputStreamWriter} import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import java.nio.ByteBuffer import java.security.PrivilegedExceptionAction -import java.util.UUID +import java.util.{Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConversions._ @@ -29,6 +30,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} +import com.google.common.base.Charsets.UTF_8 import com.google.common.base.Objects import com.google.common.io.Files @@ -247,7 +249,9 @@ private[spark] class Client( * This is used for setting up a container launch context for our ApplicationMaster. * Exposed for testing. 
*/ - def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { + def prepareLocalResources( + appStagingDir: String, + pySparkArchives: Seq[String]): HashMap[String, LocalResource] = { logInfo("Preparing resources for our AM container") // Upload Spark and the application JAR to the remote file system if necessary, // and add them as local resources to the application master. @@ -277,20 +281,6 @@ private[spark] class Client( "for alternatives.") } - // If we passed in a keytab, make sure we copy the keytab to the staging directory on - // HDFS, and setup the relevant environment vars, so the AM can login again. - if (loginFromKeytab) { - logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + - " via the YARN Secure Distributed Cache.") - val localUri = new URI(args.keytab) - val localPath = getQualifiedLocalPath(localUri, hadoopConf) - val destinationPath = copyFileToRemote(dst, localPath, replication) - val destFs = FileSystem.get(destinationPath.toUri(), hadoopConf) - distCacheMgr.addResource( - destFs, hadoopConf, destinationPath, localResources, LocalResourceType.FILE, - sparkConf.get("spark.yarn.keytab"), statCache, appMasterOnly = true) - } - def addDistributedUri(uri: URI): Boolean = { val uriStr = uri.toString() if (distributedUris.contains(uriStr)) { @@ -302,6 +292,57 @@ private[spark] class Client( } } + /** + * Distribute a file to the cluster. + * + * If the file's path is a "local:" URI, it's actually not distributed. Other files are copied + * to HDFS (if not already there) and added to the application's distributed cache. + * + * @param path URI of the file to distribute. + * @param resType Type of resource being distributed. + * @param destName Name of the file in the distributed cache. + * @param targetDir Subdirectory where to place the file. + * @param appMasterOnly Whether to distribute only to the AM. + * @return A 2-tuple. First item is whether the file is a "local:" URI. Second item is the + * localized path for non-local paths, or the input `path` for local paths. + * The localized path will be null if the URI has already been added to the cache. + */ + def distribute( + path: String, + resType: LocalResourceType = LocalResourceType.FILE, + destName: Option[String] = None, + targetDir: Option[String] = None, + appMasterOnly: Boolean = false): (Boolean, String) = { + val localURI = new URI(path.trim()) + if (localURI.getScheme != LOCAL_SCHEME) { + if (addDistributedUri(localURI)) { + val localPath = getQualifiedLocalPath(localURI, hadoopConf) + val linkname = targetDir.map(_ + "/").getOrElse("") + + destName.orElse(Option(localURI.getFragment())).getOrElse(localPath.getName()) + val destPath = copyFileToRemote(dst, localPath, replication) + distCacheMgr.addResource( + fs, hadoopConf, destPath, localResources, resType, linkname, statCache, + appMasterOnly = appMasterOnly) + (false, linkname) + } else { + (false, null) + } + } else { + (true, path.trim()) + } + } + + // If we passed in a keytab, make sure we copy the keytab to the staging directory on + // HDFS, and setup the relevant environment vars, so the AM can login again. 
+ if (loginFromKeytab) { + logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + + " via the YARN Secure Distributed Cache.") + val (_, localizedPath) = distribute(args.keytab, + destName = Some(sparkConf.get("spark.yarn.keytab")), + appMasterOnly = true) + require(localizedPath != null, "Keytab file already distributed.") + } + /** * Copy the given main resource to the distributed cache if the scheme is not "local". * Otherwise, set the corresponding key in our SparkConf to handle it downstream. @@ -314,33 +355,18 @@ private[spark] class Client( (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR), (APP_JAR, args.userJar, CONF_SPARK_USER_JAR), ("log4j.properties", oldLog4jConf.orNull, null) - ).foreach { case (destName, _localPath, confKey) => - val localPath: String = if (_localPath != null) _localPath.trim() else "" - if (!localPath.isEmpty()) { - val localURI = new URI(localPath) - if (localURI.getScheme != LOCAL_SCHEME) { - if (addDistributedUri(localURI)) { - val src = getQualifiedLocalPath(localURI, hadoopConf) - val destPath = copyFileToRemote(dst, src, replication) - val destFs = FileSystem.get(destPath.toUri(), hadoopConf) - distCacheMgr.addResource(destFs, hadoopConf, destPath, - localResources, LocalResourceType.FILE, destName, statCache) - } - } else if (confKey != null) { + ).foreach { case (destName, path, confKey) => + if (path != null && !path.trim().isEmpty()) { + val (isLocal, localizedPath) = distribute(path, destName = Some(destName)) + if (isLocal && confKey != null) { + require(localizedPath != null, s"Path $path already distributed.") // If the resource is intended for local use only, handle this downstream // by setting the appropriate property - sparkConf.set(confKey, localPath) + sparkConf.set(confKey, localizedPath) } } } - createConfArchive().foreach { file => - require(addDistributedUri(file.toURI())) - val destPath = copyFileToRemote(dst, new Path(file.toURI()), replication) - distCacheMgr.addResource(fs, hadoopConf, destPath, localResources, LocalResourceType.ARCHIVE, - LOCALIZED_HADOOP_CONF_DIR, statCache, appMasterOnly = true) - } - /** * Do the same for any additional resources passed in through ClientArguments. 
* Each resource category is represented by a 3-tuple of: @@ -356,21 +382,10 @@ private[spark] class Client( ).foreach { case (flist, resType, addToClasspath) => if (flist != null && !flist.isEmpty()) { flist.split(',').foreach { file => - val localURI = new URI(file.trim()) - if (localURI.getScheme != LOCAL_SCHEME) { - if (addDistributedUri(localURI)) { - val localPath = new Path(localURI) - val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) - val destPath = copyFileToRemote(dst, localPath, replication) - distCacheMgr.addResource( - fs, hadoopConf, destPath, localResources, resType, linkname, statCache) - if (addToClasspath) { - cachedSecondaryJarLinks += linkname - } - } - } else if (addToClasspath) { - // Resource is intended for local use only and should be added to the class path - cachedSecondaryJarLinks += file.trim() + val (_, localizedPath) = distribute(file, resType = resType) + require(localizedPath != null) + if (addToClasspath) { + cachedSecondaryJarLinks += localizedPath } } } @@ -379,11 +394,31 @@ private[spark] class Client( sparkConf.set(CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) } + if (isClusterMode && args.primaryPyFile != null) { + distribute(args.primaryPyFile, appMasterOnly = true) + } + + pySparkArchives.foreach { f => distribute(f) } + + // The python files list needs to be treated especially. All files that are not an + // archive need to be placed in a subdirectory that will be added to PYTHONPATH. + args.pyFiles.foreach { f => + val targetDir = if (f.endsWith(".py")) Some(LOCALIZED_PYTHON_DIR) else None + distribute(f, targetDir = targetDir) + } + + // Distribute an archive with Hadoop and Spark configuration for the AM. + val (_, confLocalizedPath) = distribute(createConfArchive().getAbsolutePath(), + resType = LocalResourceType.ARCHIVE, + destName = Some(LOCALIZED_CONF_DIR), + appMasterOnly = true) + require(confLocalizedPath != null) + localResources } /** - * Create an archive with the Hadoop config files for distribution. + * Create an archive with the config files for distribution. * * These are only used by the AM, since executors will use the configuration object broadcast by * the driver. The files are zipped and added to the job as an archive, so that YARN will explode @@ -395,8 +430,11 @@ private[spark] class Client( * * Currently this makes a shallow copy of the conf directory. If there are cases where a * Hadoop config directory contains subdirectories, this code will have to be fixed. + * + * The archive also contains some Spark configuration. Namely, it saves the contents of + * SparkConf in a file to be loaded by the AM process. 
*/ - private def createConfArchive(): Option[File] = { + private def createConfArchive(): File = { val hadoopConfFiles = new HashMap[String, File]() Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => sys.env.get(envKey).foreach { path => @@ -411,28 +449,32 @@ private[spark] class Client( } } - if (!hadoopConfFiles.isEmpty) { - val hadoopConfArchive = File.createTempFile(LOCALIZED_HADOOP_CONF_DIR, ".zip", - new File(Utils.getLocalDir(sparkConf))) + val confArchive = File.createTempFile(LOCALIZED_CONF_DIR, ".zip", + new File(Utils.getLocalDir(sparkConf))) + val confStream = new ZipOutputStream(new FileOutputStream(confArchive)) - val hadoopConfStream = new ZipOutputStream(new FileOutputStream(hadoopConfArchive)) - try { - hadoopConfStream.setLevel(0) - hadoopConfFiles.foreach { case (name, file) => - if (file.canRead()) { - hadoopConfStream.putNextEntry(new ZipEntry(name)) - Files.copy(file, hadoopConfStream) - hadoopConfStream.closeEntry() - } + try { + confStream.setLevel(0) + hadoopConfFiles.foreach { case (name, file) => + if (file.canRead()) { + confStream.putNextEntry(new ZipEntry(name)) + Files.copy(file, confStream) + confStream.closeEntry() } - } finally { - hadoopConfStream.close() } - Some(hadoopConfArchive) - } else { - None + // Save Spark configuration to a file in the archive. + val props = new Properties() + sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) } + confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE)) + val writer = new OutputStreamWriter(confStream, UTF_8) + props.store(writer, "Spark configuration.") + writer.flush() + confStream.closeEntry() + } finally { + confStream.close() } + confArchive } /** @@ -460,7 +502,9 @@ private[spark] class Client( /** * Set up the environment for launching our ApplicationMaster container. */ - private def setupLaunchEnv(stagingDir: String): HashMap[String, String] = { + private def setupLaunchEnv( + stagingDir: String, + pySparkArchives: Seq[String]): HashMap[String, String] = { logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.driver.extraClassPath") @@ -478,9 +522,6 @@ private[spark] class Client( val renewalInterval = getTokenRenewalInterval(stagingDirPath) sparkConf.set("spark.yarn.token.renewal.interval", renewalInterval.toString) } - // Set the environment variables to be passed on to the executors. - distCacheMgr.setDistFilesEnv(env) - distCacheMgr.setDistArchivesEnv(env) // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* val amEnvPrefix = "spark.yarn.appMasterEnv." @@ -497,15 +538,32 @@ private[spark] class Client( env("SPARK_YARN_USER_ENV") = userEnvs } - // if spark.submit.pyArchives is in sparkConf, append pyArchives to PYTHONPATH - // that can be passed on to the ApplicationMaster and the executors. - if (sparkConf.contains("spark.submit.pyArchives")) { - var pythonPath = sparkConf.get("spark.submit.pyArchives") - if (env.contains("PYTHONPATH")) { - pythonPath = Seq(env.get("PYTHONPATH"), pythonPath).mkString(File.pathSeparator) + // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH + // of the container processes too. Add all non-.py files directly to PYTHONPATH. + // + // NOTE: the code currently does not handle .py files defined with a "local:" scheme. 
+ val pythonPath = new ListBuffer[String]() + val (pyFiles, pyArchives) = args.pyFiles.partition(_.endsWith(".py")) + if (pyFiles.nonEmpty) { + pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + LOCALIZED_PYTHON_DIR) + } + (pySparkArchives ++ pyArchives).foreach { path => + val uri = new URI(path) + if (uri.getScheme != LOCAL_SCHEME) { + pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + new Path(path).getName()) + } else { + pythonPath += uri.getPath() } - env("PYTHONPATH") = pythonPath - sparkConf.setExecutorEnv("PYTHONPATH", pythonPath) + } + + // Finally, update the Spark config to propagate PYTHONPATH to the AM and executors. + if (pythonPath.nonEmpty) { + val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath) + .mkString(YarnSparkHadoopUtil.getClassPathSeparator) + env("PYTHONPATH") = pythonPathStr + sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr) } // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to @@ -555,8 +613,19 @@ private[spark] class Client( logInfo("Setting up container launch context for our AM") val appId = newAppResponse.getApplicationId val appStagingDir = getAppStagingDir(appId) - val localResources = prepareLocalResources(appStagingDir) - val launchEnv = setupLaunchEnv(appStagingDir) + val pySparkArchives = + if (sys.props.getOrElse("spark.yarn.isPython", "false").toBoolean) { + findPySparkArchives() + } else { + Nil + } + val launchEnv = setupLaunchEnv(appStagingDir, pySparkArchives) + val localResources = prepareLocalResources(appStagingDir, pySparkArchives) + + // Set the environment variables to be passed on to the executors. + distCacheMgr.setDistFilesEnv(launchEnv) + distCacheMgr.setDistArchivesEnv(launchEnv) + val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) amContainer.setLocalResources(localResources) amContainer.setEnvironment(launchEnv) @@ -596,13 +665,6 @@ private[spark] class Client( javaOpts += "-XX:CMSIncrementalDutyCycle=10" } - // Forward the Spark configuration to the application master / executors. - // TODO: it might be nicer to pass these as an internal environment variable rather than - // as Java options, due to complications with string parsing of nested quotes. 
- for ((k, v) <- sparkConf.getAll) { - javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") - } - // Include driver-specific java options if we are launching a driver if (isClusterMode) { val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions") @@ -655,14 +717,8 @@ private[spark] class Client( Nil } val primaryPyFile = - if (args.primaryPyFile != null) { - Seq("--primary-py-file", args.primaryPyFile) - } else { - Nil - } - val pyFiles = - if (args.pyFiles != null) { - Seq("--py-files", args.pyFiles) + if (isClusterMode && args.primaryPyFile != null) { + Seq("--primary-py-file", new Path(args.primaryPyFile).getName()) } else { Nil } @@ -678,9 +734,6 @@ private[spark] class Client( } else { Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName } - if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { - args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs - } if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs } @@ -688,11 +741,13 @@ private[spark] class Client( Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) } val amArgs = - Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++ + Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++ userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString) + "--num-executors ", args.numExecutors.toString, + "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) // Command for the ApplicationMaster val commands = prefixEnv ++ Seq( @@ -857,6 +912,22 @@ private[spark] class Client( } } } + + private def findPySparkArchives(): Seq[String] = { + sys.env.get("PYSPARK_ARCHIVES_PATH") + .map(_.split(",").toSeq) + .getOrElse { + val pyLibPath = Seq(sys.env("SPARK_HOME"), "python", "lib").mkString(File.separator) + val pyArchivesFile = new File(pyLibPath, "pyspark.zip") + require(pyArchivesFile.exists(), + "pyspark.zip not found; cannot run pyspark application in YARN mode.") + val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip") + require(py4jFile.exists(), + "py4j-0.8.2.1-src.zip not found; cannot run pyspark application in YARN mode.") + Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) + } + } + } object Client extends Logging { @@ -907,8 +978,14 @@ object Client extends Logging { // Distribution-defined classpath to add to processes val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH" - // Subdirectory where the user's hadoop config files will be placed. - val LOCALIZED_HADOOP_CONF_DIR = "__hadoop_conf__" + // Subdirectory where the user's Spark and Hadoop config files will be placed. + val LOCALIZED_CONF_DIR = "__spark_conf__" + + // Name of the file in the conf archive containing Spark configuration. + val SPARK_CONF_FILE = "__spark_conf__.properties" + + // Subdirectory where the user's python files (not archives) will be placed. 
+ val LOCALIZED_PYTHON_DIR = "__pyfiles__" /** * Find the user-defined Spark jar if configured, or return the jar containing this @@ -1033,7 +1110,7 @@ object Client extends Logging { if (isAM) { addClasspathEntry( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + - LOCALIZED_HADOOP_CONF_DIR, env) + LOCALIZED_CONF_DIR, env) } if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 9c7b1b3988082..35e990602a6cf 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -30,7 +30,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var archives: String = null var userJar: String = null var userClass: String = null - var pyFiles: String = null + var pyFiles: Seq[String] = Nil var primaryPyFile: String = null var primaryRFile: String = null var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() @@ -228,7 +228,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) args = tail case ("--py-files") :: value :: tail => - pyFiles = value + pyFiles = value.split(",") args = tail case ("--files") :: value :: tail => diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 99c05329b4d73..1c8d7ec57635f 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -76,7 +76,8 @@ private[spark] class YarnClientSchedulerBackend( ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), - ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue") + ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), + ("--py-files", null, "spark.submit.pyFiles") ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( @@ -86,7 +87,7 @@ private[spark] class YarnClientSchedulerBackend( optionTuples.foreach { case (optionName, envVar, sparkProp) => if (sc.getConf.contains(sparkProp)) { extraArgs += (optionName, sc.getConf.get(sparkProp)) - } else if (System.getenv(envVar) != null) { + } else if (envVar != null && System.getenv(envVar) != null) { extraArgs += (optionName, System.getenv(envVar)) if (deprecatedEnvVars.contains(envVar)) { logWarning(s"NOTE: $envVar is deprecated. 
Use ${deprecatedEnvVars(envVar)} instead.") diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 01d33c9ce9297..4ec976aa31387 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -113,7 +113,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { Environment.PWD.$() } cp should contain(pwdVar) - cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_HADOOP_CONF_DIR}") + cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_CONF_DIR}") cp should not contain (Client.SPARK_JAR) cp should not contain (Client.APP_JAR) } @@ -129,7 +129,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { val tempDir = Utils.createTempDir() try { - client.prepareLocalResources(tempDir.getAbsolutePath()) + client.prepareLocalResources(tempDir.getAbsolutePath(), Nil) sparkConf.getOption(Client.CONF_SPARK_USER_JAR) should be (Some(USER)) // The non-local path should be propagated by name only, since it will end up in the app's diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 93d587d0cb36a..a0f25ba450068 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -56,6 +56,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher """.stripMargin private val TEST_PYFILE = """ + |import mod1, mod2 |import sys |from operator import add | @@ -67,7 +68,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher | sc = SparkContext(conf=SparkConf()) | status = open(sys.argv[1],'w') | result = "failure" - | rdd = sc.parallelize(range(10)) + | rdd = sc.parallelize(range(10)).map(lambda x: x * mod1.func() * mod2.func()) | cnt = rdd.count() | if cnt == 10: | result = "success" @@ -76,6 +77,11 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher | sc.stop() """.stripMargin + private val TEST_PYMODULE = """ + |def func(): + | return 42 + """.stripMargin + private var yarnCluster: MiniYARNCluster = _ private var tempDir: File = _ private var fakeSparkJar: File = _ @@ -124,7 +130,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - hadoopConfDir = new File(tempDir, Client.LOCALIZED_HADOOP_CONF_DIR) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) assert(hadoopConfDir.mkdir()) File.createTempFile("token", ".txt", hadoopConfDir) } @@ -151,26 +157,12 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher } } - // Enable this once fix SPARK-6700 - test("run Python application in yarn-cluster mode") { - val primaryPyFile = new File(tempDir, "test.py") - Files.write(TEST_PYFILE, primaryPyFile, UTF_8) - val pyFile = new File(tempDir, "test2.py") - Files.write(TEST_PYFILE, pyFile, UTF_8) - var result = File.createTempFile("result", null, tempDir) + test("run Python application in yarn-client mode") { + testPySpark(true) + } - // The sbt assembly does not include pyspark / py4j python dependencies, so we need to - // propagate SPARK_HOME 
so that those are added to PYTHONPATH. See PythonUtils.scala. - val sparkHome = sys.props("spark.test.home") - val extraConf = Map( - "spark.executorEnv.SPARK_HOME" -> sparkHome, - "spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome) - - runSpark(false, primaryPyFile.getAbsolutePath(), - sparkArgs = Seq("--py-files", pyFile.getAbsolutePath()), - appArgs = Seq(result.getAbsolutePath()), - extraConf = extraConf) - checkResult(result) + test("run Python application in yarn-cluster mode") { + testPySpark(false) } test("user class path first in client mode") { @@ -188,6 +180,33 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher checkResult(result) } + private def testPySpark(clientMode: Boolean): Unit = { + val primaryPyFile = new File(tempDir, "test.py") + Files.write(TEST_PYFILE, primaryPyFile, UTF_8) + + val moduleDir = + if (clientMode) { + // In client-mode, .py files added with --py-files are not visible in the driver. + // This is something that the launcher library would have to handle. + tempDir + } else { + val subdir = new File(tempDir, "pyModules") + subdir.mkdir() + subdir + } + val pyModule = new File(moduleDir, "mod1.py") + Files.write(TEST_PYMODULE, pyModule, UTF_8) + + val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir) + val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") + val result = File.createTempFile("result", null, tempDir) + + runSpark(clientMode, primaryPyFile.getAbsolutePath(), + sparkArgs = Seq("--py-files", pyFiles), + appArgs = Seq(result.getAbsolutePath())) + checkResult(result) + } + private def testUseClassPathFirst(clientMode: Boolean): Unit = { // Create a jar file that contains a different version of "test.resource". val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir) From 30ebf1a233295539c2455bd838bae7315711e1e2 Mon Sep 17 00:00:00 2001 From: Hossein Date: Wed, 10 Jun 2015 13:18:48 -0700 Subject: [PATCH 02/18] [SPARK-8282] [SPARKR] Make number of threads used in RBackend configurable Read number of threads for RBackend from configuration. [SPARK-8282] #comment Linking with JIRA Author: Hossein Closes #6730 from falaki/SPARK-8282 and squashes the following commits: 33b3d98 [Hossein] Documented new config parameter 70f2a9c [Hossein] Fixing import ec44225 [Hossein] Read number of threads for RBackend from configuration --- .../main/scala/org/apache/spark/api/r/RBackend.scala | 5 +++-- docs/configuration.md | 12 ++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index d24c650d37bb0..1a5f2bca26c2b 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -29,7 +29,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.LengthFieldBasedFrameDecoder import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkConf} /** * Netty-based backend server that is used to communicate between R and Java. 
@@ -41,7 +41,8 @@ private[spark] class RBackend {
   private[this] var bossGroup: EventLoopGroup = null
 
   def init(): Int = {
-    bossGroup = new NioEventLoopGroup(2)
+    val conf = new SparkConf()
+    bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2))
     val workerGroup = bossGroup
     val handler = new RBackendHandler(this)
 
diff --git a/docs/configuration.md b/docs/configuration.md
index 3960e7e78bde1..95a322f79b40b 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1495,6 +1495,18 @@ Apart from these, the following properties are also available, and may be useful
 
+#### SparkR
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.r.numRBackendThreads</code></td>
+  <td>2</td>
+  <td>
+    Number of threads used by RBackend to handle RPC calls from SparkR package.
+  </td>
+</tr>
+</table>
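As a usage note (not part of the diff), the new property goes through the normal SparkConf machinery, so it can be supplied via spark-defaults.conf, `--conf spark.r.numRBackendThreads=...`, or programmatically; a small illustrative sketch:

```scala
import org.apache.spark.SparkConf

object RBackendThreadsExample {
  def main(args: Array[String]): Unit = {
    // In practice the value usually comes from spark-defaults.conf or
    // `--conf spark.r.numRBackendThreads=4`; setting it here is only for illustration.
    val conf = new SparkConf().set("spark.r.numRBackendThreads", "4")
    // Same lookup as the patched RBackend.init(): default to 2 threads if unset.
    val numThreads = conf.getInt("spark.r.numRBackendThreads", 2)
    println(s"RBackend boss group would use $numThreads thread(s)")
  }
}
```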
+ #### Cluster Managers Each cluster manager in Spark has additional configuration options. Configurations can be found on the pages for each mode: From 19e30b48f3c6d0b72871d3e15b9564c1b2822700 Mon Sep 17 00:00:00 2001 From: Adam Roberts Date: Wed, 10 Jun 2015 13:21:01 -0700 Subject: [PATCH 03/18] [SPARK-7756] CORE RDDOperationScope fix for IBM Java IBM Java has an extra method when we do getStackTrace(): this is "getStackTraceImpl", a native method. This causes two tests to fail within "DStreamScopeSuite" when running with IBM Java. Instead of "map" or "filter" being the method names found, "getStackTrace" is returned. This commit addresses such an issue by using dropWhile. Given that our current method is withScope, we look for the next method that isn't ours: we don't care about methods that come before us in the stack trace: e.g. getStackTrace (regardless of how many levels this might go). IBM: java.lang.Thread.getStackTraceImpl(Native Method) java.lang.Thread.getStackTrace(Thread.java:1117) org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:104) Oracle: PRINTING STACKTRACE!!! java.lang.Thread.getStackTrace(Thread.java:1552) org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:106) I've tested this with Oracle and IBM Java, no side effects for other tests introduced. Author: Adam Roberts Author: a-roberts Closes #6740 from a-roberts/RDDScopeStackCrawlFix and squashes the following commits: 13ce390 [Adam Roberts] Ensure consistency with String equality checking a4fc0e0 [a-roberts] Update RDDOperationScope.scala --- .../scala/org/apache/spark/rdd/RDDOperationScope.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala index 6b09dfafc889c..44667281c1063 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -95,10 +95,9 @@ private[spark] object RDDOperationScope extends Logging { private[spark] def withScope[T]( sc: SparkContext, allowNesting: Boolean = false)(body: => T): T = { - val stackTrace = Thread.currentThread.getStackTrace().tail // ignore "Thread#getStackTrace" - val ourMethodName = stackTrace(1).getMethodName // i.e. withScope - // Climb upwards to find the first method that's called something different - val callerMethodName = stackTrace + val ourMethodName = "withScope" + val callerMethodName = Thread.currentThread.getStackTrace() + .dropWhile(_.getMethodName != ourMethodName) .find(_.getMethodName != ourMethodName) .map(_.getMethodName) .getOrElse { From e90c9d92d9a86e9960c10a5c043f3c02f6c636f9 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 10 Jun 2015 13:22:52 -0700 Subject: [PATCH 04/18] [SPARK-7527] [CORE] Fix createNullValue to return the correct null values and REPL mode detection The root cause of SPARK-7527 is `createNullValue` returns an incompatible value `Byte(0)` for `char` and `boolean`. This PR fixes it and corrects the class name of the main class, and also adds an unit test to demonstrate it. 
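As an illustrative aside (not part of this patch), the incompatibility can be reproduced with a tiny self-contained Scala program; `Example` is a made-up class whose constructor takes primitive parameters:

```scala
// Made-up class standing in for a closure whose constructor takes primitives.
class Example(flag: Boolean, count: Int)

object CreateNullValueSketch {
  def main(args: Array[String]): Unit = {
    val ctor = classOf[Example].getConstructors()(0)
    try {
      // Byte(0) unboxes and widens to int/long/float/double, but not to boolean
      // or char, so this fails with "argument type mismatch".
      ctor.newInstance(new java.lang.Byte(0: Byte), new java.lang.Byte(0: Byte))
    } catch {
      case e: IllegalArgumentException => println(s"rejected: ${e.getMessage}")
    }
    // Type-appropriate defaults, as the fixed createNullValue returns, work fine.
    println(ctor.newInstance(java.lang.Boolean.FALSE, java.lang.Integer.valueOf(0)))
  }
}
```

Reflection unboxes and widens a `Byte` to the numeric primitives, but not to `boolean` or `char`, which is why a single `Byte(0)` default broke those two cases.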
Author: zsxwing Closes #6735 from zsxwing/SPARK-7527 and squashes the following commits: bbdb271 [zsxwing] Use pattern match in createNullValue b0a0e7e [zsxwing] Remove the noisy in the test output 903e269 [zsxwing] Remove the code for Utils.isInInterpreter == false 5f92dc1 [zsxwing] Fix createNullValue to return the correct null values and REPL mode detection --- .../apache/spark/util/ClosureCleaner.scala | 40 ++++++++--------- .../scala/org/apache/spark/util/Utils.scala | 9 +--- .../spark/util/ClosureCleanerSuite.scala | 44 +++++++++++++++++++ 3 files changed, 64 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 6f2966bd4fd31..305de4c75539d 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -109,7 +109,14 @@ private[spark] object ClosureCleaner extends Logging { private def createNullValue(cls: Class[_]): AnyRef = { if (cls.isPrimitive) { - new java.lang.Byte(0: Byte) // Should be convertible to any primitive type + cls match { + case java.lang.Boolean.TYPE => new java.lang.Boolean(false) + case java.lang.Character.TYPE => new java.lang.Character('\0') + case java.lang.Void.TYPE => + // This should not happen because `Foo(void x) {}` does not compile. + throw new IllegalStateException("Unexpected void parameter in constructor") + case _ => new java.lang.Byte(0: Byte) + } } else { null } @@ -319,28 +326,17 @@ private[spark] object ClosureCleaner extends Logging { private def instantiateClass( cls: Class[_], enclosingObject: AnyRef): AnyRef = { - if (!Utils.isInInterpreter) { - // This is a bona fide closure class, whose constructor has no effects - // other than to set its fields, so use its constructor - val cons = cls.getConstructors()(0) - val params = cons.getParameterTypes.map(createNullValue).toArray - if (enclosingObject != null) { - params(0) = enclosingObject // First param is always enclosing object - } - return cons.newInstance(params: _*).asInstanceOf[AnyRef] - } else { - // Use reflection to instantiate object without calling constructor - val rf = sun.reflect.ReflectionFactory.getReflectionFactory() - val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() - val newCtor = rf.newConstructorForSerialization(cls, parentCtor) - val obj = newCtor.newInstance().asInstanceOf[AnyRef] - if (enclosingObject != null) { - val field = cls.getDeclaredField("$outer") - field.setAccessible(true) - field.set(obj, enclosingObject) - } - obj + // Use reflection to instantiate object without calling constructor + val rf = sun.reflect.ReflectionFactory.getReflectionFactory() + val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() + val newCtor = rf.newConstructorForSerialization(cls, parentCtor) + val obj = newCtor.newInstance().asInstanceOf[AnyRef] + if (enclosingObject != null) { + val field = cls.getDeclaredField("$outer") + field.setAccessible(true) + field.set(obj, enclosingObject) } + obj } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 153ece6224a6d..19157af5b6f4d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1804,15 +1804,10 @@ private[spark] object Utils extends Logging { lazy val isInInterpreter: Boolean = { try { - val interpClass = classForName("spark.repl.Main") + val 
interpClass = classForName("org.apache.spark.repl.Main") interpClass.getMethod("interp").invoke(null) != null } catch { - // Returning true seems to be a mistake. - // Currently changing it to false causes tests failures in Streaming. - // For a more detailed discussion, please, refer to - // https://github.com/apache/spark/pull/5835#issuecomment-101042271 and subsequent comments. - // Addressing this changed is tracked as https://issues.apache.org/jira/browse/SPARK-7527 - case _: ClassNotFoundException => true + case _: ClassNotFoundException => false } } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 70cd27b04347d..1053c6caf7718 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -121,6 +121,10 @@ class ClosureCleanerSuite extends SparkFunSuite { expectCorrectException { TestUserClosuresActuallyCleaned.testSubmitJob(sc) } } } + + test("createNullValue") { + new TestCreateNullValue().run() + } } // A non-serializable class we create in closures to make sure that we aren't @@ -350,3 +354,43 @@ private object TestUserClosuresActuallyCleaned { ) } } + +class TestCreateNullValue { + + var x = 5 + + def getX: Int = x + + def run(): Unit = { + val bo: Boolean = true + val c: Char = '1' + val b: Byte = 1 + val s: Short = 1 + val i: Int = 1 + val l: Long = 1 + val f: Float = 1 + val d: Double = 1 + + // Bring in all primitive types into the closure such that they become + // parameters of the closure constructor. This allows us to test whether + // null values are created correctly for each type. + val nestedClosure = () => { + if (s.toString == "123") { // Don't really output them to avoid noisy + println(bo) + println(c) + println(b) + println(s) + println(i) + println(l) + println(f) + println(d) + } + + val closure = () => { + println(getX) + } + ClosureCleaner.clean(closure) + } + nestedClosure() + } +} From 80043e9e761c44ce2c3a432dcd1989be573f8bb4 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 10 Jun 2015 13:25:59 -0700 Subject: [PATCH 05/18] [SPARK-7261] [CORE] Change default log level to WARN in the REPL 1. Add `log4j-defaults-repl.properties` that has log level WARN. 2. When logging is initialized, check whether inside the REPL. If so, use `log4j-defaults-repl.properties`. 3. 
Print the following information if using `log4j-defaults-repl.properties`: ``` Using Spark's repl log4j profile: org/apache/spark/log4j-defaults-repl.properties To adjust logging level use sc.setLogLevel("INFO") ``` Author: zsxwing Closes #6734 from zsxwing/log4j-repl and squashes the following commits: 3835eff [zsxwing] Change default log level to WARN in the REPL --- .rat-excludes | 1 + .../spark/log4j-defaults-repl.properties | 12 +++++++++ .../main/scala/org/apache/spark/Logging.scala | 26 ++++++++++++++----- 3 files changed, 32 insertions(+), 7 deletions(-) create mode 100644 core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties diff --git a/.rat-excludes b/.rat-excludes index 994c7e86f8a91..aa008e6e920f5 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -28,6 +28,7 @@ spark-env.sh spark-env.cmd spark-env.sh.template log4j-defaults.properties +log4j-defaults-repl.properties bootstrap-tooltip.js jquery-1.11.1.min.js d3.min.js diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties new file mode 100644 index 0000000000000..b146f8a784127 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -0,0 +1,12 @@ +# Set everything to be logged to the console +log4j.rootCategory=WARN, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 419d093d55643..7fcb7830e7b0b 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -121,13 +121,25 @@ trait Logging { if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements if (!log4j12Initialized) { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") + if (Utils.isInInterpreter) { + val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" + Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's repl log4j profile: $replDefaultLogProps") + System.err.println("To adjust logging level use sc.setLogLevel(\"INFO\")") + case None => + System.err.println(s"Spark was unable to load $replDefaultLogProps") + } + } else { + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + 
System.err.println(s"Spark was unable to load $defaultLogProps") + } } } } From cb871c44c38a4c1575ed076389f14641afafad7d Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Wed, 10 Jun 2015 13:30:16 -0700 Subject: [PATCH 06/18] [SPARK-8290] spark class command builder need read SPARK_JAVA_OPTS and SPARK_DRIVER_MEMORY properly SPARK_JAVA_OPTS was missed in reconstructing the launcher part, we should add it back so process launched by spark-class could read it properly. And so does `SPARK_DRIVER_MEMORY`. The missing part is [here](https://github.com/apache/spark/blob/1c30afdf94b27e1ad65df0735575306e65d148a1/bin/spark-class#L97). Author: WangTaoTheTonic Author: Tao Wang Closes #6741 from WangTaoTheTonic/SPARK-8290 and squashes the following commits: bd89f0f [Tao Wang] make sure the memory setting is right too e313520 [WangTaoTheTonic] spark class command builder need read SPARK_JAVA_OPTS --- .../org/apache/spark/launcher/SparkClassCommandBuilder.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index d80abf2a8676e..de85720febf23 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -93,6 +93,9 @@ public List buildCommand(Map env) throws IOException { toolsDir.getAbsolutePath(), className); javaOptsKeys.add("SPARK_JAVA_OPTS"); + } else { + javaOptsKeys.add("SPARK_JAVA_OPTS"); + memKey = "SPARK_DRIVER_MEMORY"; } List cmd = buildJavaCommand(extraClassPath); From 5014d0ed7e2f69810654003f8dd38078b945cf05 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Wed, 10 Jun 2015 13:34:19 -0700 Subject: [PATCH 07/18] [SPARK-8273] Driver hangs up when yarn shutdown in client mode In client mode, if yarn was shut down with spark application running, the application will hang up after several retries(default: 30) because the exception throwed by YarnClientImpl could not be caught by upper level, we should exit in case that user can not be aware that. The exception we wanna catch is [here](https://github.com/apache/hadoop/blob/branch-2.7.0/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/io/retry/RetryInvocationHandler.java#L122), and I try to fix it refer to [MR](https://github.com/apache/hadoop/blob/branch-2.7.0/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/main/java/org/apache/hadoop/mapred/ClientServiceDelegate.java#L320). 
Author: WangTaoTheTonic Closes #6717 from WangTaoTheTonic/SPARK-8273 and squashes the following commits: 28752d6 [WangTaoTheTonic] catch the throwed exception --- yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index ec9402afff329..da1ec2a0fe2e9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -29,6 +29,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} +import scala.util.control.NonFatal import com.google.common.base.Charsets.UTF_8 import com.google.common.base.Objects @@ -826,6 +827,9 @@ private[spark] class Client( case e: ApplicationNotFoundException => logError(s"Application $appId not found.") return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + case NonFatal(e) => + logError(s"Failed to contact YARN for application $appId.", e) + return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED) } val state = report.getYarnApplicationState From 96a7c888d806adfdb2c722025a1079ed7eaa2052 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 10 Jun 2015 15:03:40 -0700 Subject: [PATCH 08/18] [SPARK-2774] Set preferred locations for reduce tasks Set preferred locations for reduce tasks. The basic design is that we maintain a map from reducerId to a list of (sizes, locations) for each shuffle. We then set the preferred locations to be any machines that have 20% of more of the output that needs to be read by the reduce task. This will result in at most 5 preferred locations for each reduce task. Selecting the preferred locations involves O(# map tasks * # reduce tasks) computation, so we restrict this feature to cases where we have fewer than 1000 map tasks and 1000 reduce tasks. 
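As an illustrative aside (assumed input shape, not the actual MapOutputTracker API), the per-reducer aggregation described above boils down to summing block sizes per host and keeping the hosts at or above the 0.2 fraction threshold, which caps the result at five locations:

```scala
object ReduceLocalitySketch {
  // Assumed input: (host, blockSize) pairs for a single reducer's shuffle input.
  // Return hosts holding at least `fraction` of the reducer's total input.
  def largestOutputLocations(
      blockSizesByHost: Seq[(String, Long)],
      fraction: Double = 0.2): Seq[String] = {
    val perHost = new scala.collection.mutable.HashMap[String, Long]()
    var total = 0L
    blockSizesByHost.foreach { case (host, size) =>
      if (size > 0) {                       // skip empty blocks
        perHost(host) = perHost.getOrElse(host, 0L) + size
        total += size
      }
    }
    // With fraction = 0.2, at most five hosts can qualify.
    perHost.collect {
      case (host, size) if size.toDouble / total >= fraction => host
    }.toSeq
  }

  def main(args: Array[String]): Unit = {
    val sizes = Seq(("host1", 400L), ("host1", 300L), ("host2", 200L), ("host3", 100L))
    println(largestOutputLocations(sizes)) // host1 (0.7) and host2 (0.2) qualify
  }
}
```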
Author: Shivaram Venkataraman Closes #6652 from shivaram/reduce-locations and squashes the following commits: 492e25e [Shivaram Venkataraman] Remove unused import 2ef2d39 [Shivaram Venkataraman] Address code review comments 897a914 [Shivaram Venkataraman] Remove unused hash map f5be578 [Shivaram Venkataraman] Use fraction of map outputs to determine locations Also removes caching of preferred locations to make the API cleaner 68bc29e [Shivaram Venkataraman] Fix line length 1090b58 [Shivaram Venkataraman] Change flag name 77ce7d8 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into reduce-locations e5d56bd [Shivaram Venkataraman] Add flag to turn off locality for shuffle deps 6cfae98 [Shivaram Venkataraman] Filter out zero blocks, rename variables 9d5831a [Shivaram Venkataraman] Address some more comments 8e31266 [Shivaram Venkataraman] Fix style 0df3180 [Shivaram Venkataraman] Address code review comments e7d5449 [Shivaram Venkataraman] Fix merge issues ad7cb53 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into reduce-locations df14cee [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into reduce-locations 5093aea [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into reduce-locations 0171d3c [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into reduce-locations bc4dfd6 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into reduce-locations 774751b [Shivaram Venkataraman] Fix bug introduced by line length adjustment 34d0283 [Shivaram Venkataraman] Fix style issues 3b464b7 [Shivaram Venkataraman] Set preferred locations for reduce tasks This is another attempt at #1697 addressing some of the earlier concerns. This adds a couple of thresholds based on number map and reduce tasks beyond which we don't use preferred locations for reduce tasks. --- .../org/apache/spark/MapOutputTracker.scala | 49 +++++++++++- .../apache/spark/scheduler/DAGScheduler.scala | 37 ++++++++- .../apache/spark/MapOutputTrackerSuite.scala | 35 +++++++++ .../spark/scheduler/DAGSchedulerSuite.scala | 76 +++++++++++++++---- 4 files changed, 177 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 018422827e1c8..862ffe868f58f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,7 +21,7 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, Map} +import scala.collection.mutable.{HashMap, HashSet, Map} import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -284,6 +284,53 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) } + /** + * Return a list of locations that each have fraction of map output greater than the specified + * threshold. + * + * @param shuffleId id of the shuffle + * @param reducerId id of the reduce task + * @param numReducers total number of reducers in the shuffle + * @param fractionThreshold fraction of total map output size that a location must have + * for it to be considered large. + * + * This method is not thread-safe. 
+ */ + def getLocationsWithLargestOutputs( + shuffleId: Int, + reducerId: Int, + numReducers: Int, + fractionThreshold: Double) + : Option[Array[BlockManagerId]] = { + + if (mapStatuses.contains(shuffleId)) { + val statuses = mapStatuses(shuffleId) + if (statuses.nonEmpty) { + // HashMap to add up sizes of all blocks at the same location + val locs = new HashMap[BlockManagerId, Long] + var totalOutputSize = 0L + var mapIdx = 0 + while (mapIdx < statuses.length) { + val status = statuses(mapIdx) + val blockSize = status.getSizeForBlock(reducerId) + if (blockSize > 0) { + locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize + totalOutputSize += blockSize + } + mapIdx = mapIdx + 1 + } + val topLocs = locs.filter { case (loc, size) => + size.toDouble / totalOutputSize >= fractionThreshold + } + // Return if we have any locations which satisfy the required threshold + if (topLocs.nonEmpty) { + return Some(topLocs.map(_._1).toArray) + } + } + } + None + } + def incrementEpoch() { epochLock.synchronized { epoch += 1 diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 75a567fb31520..aea6674ed20be 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -137,6 +137,22 @@ class DAGScheduler( private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) + // Flag to control if reduce tasks are assigned preferred locations + private val shuffleLocalityEnabled = + sc.getConf.getBoolean("spark.shuffle.reduceLocality.enabled", true) + // Number of map, reduce tasks above which we do not assign preferred locations + // based on map output sizes. We limit the size of jobs for which assign preferred locations + // as computing the top locations by size becomes expensive. + private[this] val SHUFFLE_PREF_MAP_THRESHOLD = 1000 + // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that + private[this] val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000 + + // Fraction of total map output that must be at a location for it to considered as a preferred + // location for a reduce task. + // Making this larger will focus on fewer locations where most data can be read locally, but + // may lead to more delay in scheduling if those locations are busy. + private[scheduler] val REDUCER_PREF_LOCS_FRACTION = 0.2 + // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) @@ -1384,17 +1400,32 @@ class DAGScheduler( if (rddPrefs.nonEmpty) { return rddPrefs.map(TaskLocation(_)) } - // If the RDD has narrow dependencies, pick the first partition of the first narrow dep - // that has any placement preferences. Ideally we would choose based on transfer sizes, - // but this will do for now. + rdd.dependencies.foreach { case n: NarrowDependency[_] => + // If the RDD has narrow dependencies, pick the first partition of the first narrow dep + // that has any placement preferences. Ideally we would choose based on transfer sizes, + // but this will do for now. 
for (inPart <- n.getParents(partition)) { val locs = getPreferredLocsInternal(n.rdd, inPart, visited) if (locs != Nil) { return locs } } + case s: ShuffleDependency[_, _, _] => + // For shuffle dependencies, pick locations which have at least REDUCER_PREF_LOCS_FRACTION + // of data as preferred locations + if (shuffleLocalityEnabled && + rdd.partitions.size < SHUFFLE_PREF_REDUCE_THRESHOLD && + s.rdd.partitions.size < SHUFFLE_PREF_MAP_THRESHOLD) { + // Get the preferred map output locations for this reducer + val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId, + partition, rdd.partitions.size, REDUCER_PREF_LOCS_FRACTION) + if (topLocsForReducer.nonEmpty) { + return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId)) + } + } + case _ => } Nil diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 1fab69678d040..7a1961137cce5 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -205,4 +205,39 @@ class MapOutputTrackerSuite extends SparkFunSuite { // masterTracker.stop() // this throws an exception rpcEnv.shutdown() } + + test("getLocationsWithLargestOutputs with multiple outputs in same machine") { + val rpcEnv = createRpcEnv("test") + val tracker = new MapOutputTrackerMaster(conf) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + // Setup 3 map tasks + // on hostA with output size 2 + // on hostA with output size 2 + // on hostB with output size 3 + tracker.registerShuffle(10, 3) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(2L))) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(2L))) + tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(3L))) + + // When the threshold is 50%, only host A should be returned as a preferred location + // as it has 4 out of 7 bytes of output. + val topLocs50 = tracker.getLocationsWithLargestOutputs(10, 0, 1, 0.5) + assert(topLocs50.nonEmpty) + assert(topLocs50.get.size === 1) + assert(topLocs50.get.head === BlockManagerId("a", "hostA", 1000)) + + // When the threshold is 20%, both hosts should be returned as preferred locations. 
+ val topLocs20 = tracker.getLocationsWithLargestOutputs(10, 0, 1, 0.2) + assert(topLocs20.nonEmpty) + assert(topLocs20.get.size === 2) + assert(topLocs20.get.toSet === + Seq(BlockManagerId("a", "hostA", 1000), BlockManagerId("b", "hostB", 1000)).toSet) + + tracker.stop() + rpcEnv.shutdown() + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 47b2868753c0e..833b600746e90 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -490,8 +490,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), @@ -501,7 +501,7 @@ class DAGSchedulerSuite // ask the scheduler to try it again scheduler.resubmitFailedStages() // have the 2nd attempt pass - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) // we can see both result blocks now assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) @@ -517,8 +517,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) // The MapOutputTracker should know about both map output locations. 
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) @@ -560,18 +560,18 @@ class DAGSchedulerSuite assert(newEpoch > oldEpoch) val taskSet = taskSets(0) // should be ignored for being too old - runEvent(CompletionEvent( - taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", + reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) // should work because it's a non-failed host - runEvent(CompletionEvent( - taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", + reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) // should be ignored for being too old - runEvent(CompletionEvent( - taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", + reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) // should work because it's a new epoch taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent( - taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", + reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -800,6 +800,50 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + test("reduce tasks should be placed locally with map output") { + // Create an shuffleMapRdd with 1 partition + val shuffleMapRdd = new MyRDD(sc, 1, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)))) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostA"))) + + // Reducer should run on the same host that map task ran + val reduceTaskSet = taskSets(1) + assertLocations(reduceTaskSet, Seq(Seq("hostA"))) + complete(reduceTaskSet, Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty + } + + test("reduce task locality preferences should only include machines with largest map outputs") { + val numMapTasks = 4 + // Create an shuffleMapRdd with more partitions + val shuffleMapRdd = new MyRDD(sc, numMapTasks, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + + val statuses = (1 to numMapTasks).map { i => + (Success, makeMapStatus("host" + i, 1, (10*i).toByte)) + } + complete(taskSets(0), statuses) + + // Reducer should prefer the last 3 hosts as they have 20%, 30% and 40% of data + val hosts = (1 to numMapTasks).map(i => "host" + i).reverse.take(numMapTasks - 1) + + val reduceTaskSet = taskSets(1) + assertLocations(reduceTaskSet, Seq(hosts)) + complete(reduceTaskSet, Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. 
* Note that this checks only the host and not the executor ID. @@ -807,12 +851,12 @@ class DAGSchedulerSuite private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]) { assert(hosts.size === taskSet.tasks.size) for ((taskLocs, expectedLocs) <- taskSet.tasks.map(_.preferredLocations).zip(hosts)) { - assert(taskLocs.map(_.host) === expectedLocs) + assert(taskLocs.map(_.host).toSet === expectedLocs.toSet) } } - private def makeMapStatus(host: String, reduces: Int): MapStatus = - MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(2)) + private def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus = + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes)) private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) From b928f543845ddd39e914a0e8f0b0205fd86100c5 Mon Sep 17 00:00:00 2001 From: Paavo Date: Wed, 10 Jun 2015 23:17:42 +0100 Subject: [PATCH 09/18] [SPARK-8200] [MLLIB] Check for empty RDDs in StreamingLinearAlgorithm Test cases for both StreamingLinearRegression and StreamingLogisticRegression, and code fix. Edit: This contribution is my original work and I license the work to the project under the project's open source license. Author: Paavo Closes #6713 from pparkkin/streamingmodel-empty-rdd and squashes the following commits: ff5cd78 [Paavo] Update strings to use interpolation. db234cf [Paavo] Use !rdd.isEmpty. 54ad89e [Paavo] Test case for empty stream. 393e36f [Paavo] Ignore empty RDDs. 0bfc365 [Paavo] Test case for empty stream. --- .../regression/StreamingLinearAlgorithm.scala | 14 ++++++++------ .../StreamingLogisticRegressionSuite.scala | 17 +++++++++++++++++ .../StreamingLinearRegressionSuite.scala | 18 ++++++++++++++++++ 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index aee51bf22d8d0..141052ba813ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -83,13 +83,15 @@ abstract class StreamingLinearAlgorithm[ throw new IllegalArgumentException("Model must be initialized before starting training.") } data.foreachRDD { (rdd, time) => - model = Some(algorithm.run(rdd, model.get.weights)) - logInfo("Model updated at time %s".format(time.toString)) - val display = model.get.weights.size match { - case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") - case _ => model.get.weights.toArray.mkString("[", ",", "]") + if (!rdd.isEmpty) { + model = Some(algorithm.run(rdd, model.get.weights)) + logInfo(s"Model updated at time ${time.toString}") + val display = model.get.weights.size match { + case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") + case _ => model.get.weights.toArray.mkString("[", ",", "]") + } + logInfo(s"Current model: weights, ${display}") } - logInfo("Current model: weights, %s".format (display)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index e98b61e13e21f..fd653296c9d97 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -158,4 +158,21 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList assert(error.head > 0.8 & error.last < 0.2) } + + // Test empty RDDs in a stream + test("handling empty RDDs in a stream") { + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(-0.1)) + .setStepSize(0.01) + .setNumIterations(10) + val numBatches = 10 + val emptyInput = Seq.empty[Seq[LabeledPoint]] + val ssc = setupStreams(emptyInput, + (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + } + ) + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 9a379406d5061..f5e2d31056cbd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -166,4 +166,22 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList assert((error.head - error.last) > 2) } + + // Test empty RDDs in a stream + test("handling empty RDDs in a stream") { + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.2) + .setNumIterations(25) + val numBatches = 10 + val nPoints = 100 + val emptyInput = Seq.empty[Seq[LabeledPoint]] + val ssc = setupStreams(emptyInput, + (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + } + ) + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + } } From 37719e0cd0b00cc5ffee0ebe1652d465a574db0f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 10 Jun 2015 16:55:39 -0700 Subject: [PATCH 10/18] [SPARK-8189] [SQL] use Long for TimestampType in SQL This PR changes TimestampType to use Long as its internal type, for efficiency; as a result, precision finer than 100ns is lost.
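Before the diffs, a small self-contained sketch of the new encoding may help: a timestamp is stored as the number of 100ns intervals since the epoch, which is why anything finer than 100ns cannot be represented. The helpers below mirror the DateUtils.fromJavaTimestamp/toJavaTimestamp added in this patch, but the names and the standalone object are illustrative only:

import java.sql.Timestamp

// Sketch: encode a Timestamp as 100ns intervals since the epoch, and decode it back.
object TimestampEncodingSketch {
  private val HUNDRED_NANOS_PER_SECOND = 10000000L

  def toLong(t: Timestamp): Long =
    // millisecond part from getTime, sub-millisecond part (in 100ns units) from getNanos
    t.getTime * 10000L + (t.getNanos.toLong / 100) % 10000L

  def toTimestamp(num100ns: Long): Timestamp = {
    var seconds = num100ns / HUNDRED_NANOS_PER_SECOND
    var nanos = num100ns % HUNDRED_NANOS_PER_SECOND
    if (nanos < 0) { nanos += HUNDRED_NANOS_PER_SECOND; seconds -= 1 } // setNanos rejects negatives
    val ts = new Timestamp(seconds * 1000)
    ts.setNanos(nanos.toInt * 100)
    ts
  }

  def main(args: Array[String]): Unit = {
    val now = new Timestamp(System.currentTimeMillis())
    now.setNanos(now.getNanos + 100) // 100ns is the finest increment that survives a round trip
    assert(toTimestamp(toLong(now)) == now)
  }
}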
Author: Davies Liu Closes #6733 from davies/timestamp and squashes the following commits: d9565fa [Davies Liu] remove print 65cf2f1 [Davies Liu] fix Timestamp in SparkR 86fecfb [Davies Liu] disable two timestamp tests 8f77ee0 [Davies Liu] fix scala style 246ee74 [Davies Liu] address comments 309d2e1 [Davies Liu] use Long for TimestampType in SQL --- .../scala/org/apache/spark/api/r/SerDe.scala | 17 +++-- python/pyspark/sql/types.py | 11 ++++ .../scala/org/apache/spark/sql/BaseRow.java | 6 ++ .../main/scala/org/apache/spark/sql/Row.scala | 8 ++- .../sql/catalyst/CatalystTypeConverters.scala | 13 +++- .../spark/sql/catalyst/expressions/Cast.scala | 62 +++++++++---------- .../expressions/SpecificMutableRow.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 4 +- .../codegen/GenerateProjection.scala | 10 ++- .../sql/catalyst/expressions/literals.scala | 15 +++-- .../sql/catalyst/expressions/predicates.scala | 6 +- .../spark/sql/catalyst/util/DateUtils.scala | 44 ++++++++++--- .../spark/sql/types/TimestampType.scala | 10 +-- .../sql/catalyst/expressions/CastSuite.scala | 11 ++-- .../sql/catalyst/util/DateUtilsSuite.scala | 40 ++++++++++++ .../spark/sql/types/DataTypeSuite.scala | 2 +- .../spark/sql/columnar/ColumnStats.scala | 21 +------ .../spark/sql/columnar/ColumnType.scala | 19 +++--- .../sql/execution/SparkSqlSerializer2.scala | 17 ++--- .../spark/sql/execution/debug/package.scala | 2 + .../spark/sql/execution/pythonUdfs.scala | 7 ++- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 10 ++- .../apache/spark/sql/json/JacksonParser.scala | 5 +- .../org/apache/spark/sql/json/JsonRDD.scala | 10 ++- .../spark/sql/parquet/ParquetConverter.scala | 9 +-- .../sql/parquet/ParquetTableSupport.scala | 10 +-- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 2 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 11 ++-- .../sql/columnar/ColumnarTestUtils.scala | 9 +-- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 +- .../org/apache/spark/sql/json/JsonSuite.scala | 14 +++-- .../execution/HiveCompatibilitySuite.scala | 8 ++- .../spark/sql/hive/HiveInspectors.scala | 20 +++--- .../apache/spark/sql/hive/TableReader.scala | 4 +- ...cast #5-0-dbd7bcd167d322d6617b884c02c7f247 | 2 +- 36 files changed, 272 insertions(+), 172 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index f8e3f1a79082e..56adc857d4ce0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.r import java.io.{DataInputStream, DataOutputStream} -import java.sql.{Date, Time} +import java.sql.{Timestamp, Date, Time} import scala.collection.JavaConversions._ @@ -107,9 +107,12 @@ private[spark] object SerDe { Date.valueOf(readString(in)) } - def readTime(in: DataInputStream): Time = { - val t = in.readDouble() - new Time((t * 1000L).toLong) + def readTime(in: DataInputStream): Timestamp = { + val seconds = in.readDouble() + val sec = Math.floor(seconds).toLong + val t = new Timestamp(sec * 1000L) + t.setNanos(((seconds - sec) * 1e9).toInt) + t } def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { @@ -227,6 +230,9 @@ private[spark] object SerDe { case "java.sql.Time" => writeType(dos, "time") writeTime(dos, value.asInstanceOf[Time]) + case "java.sql.Timestamp" => + writeType(dos, 
"time") + writeTime(dos, value.asInstanceOf[Timestamp]) case "[B" => writeType(dos, "raw") writeBytes(dos, value.asInstanceOf[Array[Byte]]) @@ -289,6 +295,9 @@ private[spark] object SerDe { out.writeDouble(value.getTime.toDouble / 1000.0) } + def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) + } // NOTE: Only works for ASCII right now def writeString(out: DataOutputStream, value: String): Unit = { diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index b6ec6137c9180..8f286b631f4f0 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -19,6 +19,7 @@ import decimal import time import datetime +import calendar import keyword import warnings import json @@ -654,6 +655,8 @@ def _need_python_to_sql_conversion(dataType): _need_python_to_sql_conversion(dataType.valueType) elif isinstance(dataType, UserDefinedType): return True + elif isinstance(dataType, TimestampType): + return True else: return False @@ -707,6 +710,14 @@ def converter(obj): return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) elif isinstance(dataType, UserDefinedType): return lambda obj: dataType.serialize(obj) + elif isinstance(dataType, TimestampType): + + def to_posix_timstamp(dt): + if dt.tzinfo is None: + return int(time.mktime(dt.timetuple()) * 1e7 + dt.microsecond * 10) + else: + return int(calendar.timegm(dt.utctimetuple()) * 1e7 + dt.microsecond * 10) + return to_posix_timstamp else: raise ValueError("Unexpected type %r" % dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java index d138b43a3482b..6584882a62fd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java @@ -19,6 +19,7 @@ import java.math.BigDecimal; import java.sql.Date; +import java.sql.Timestamp; import java.util.List; import scala.collection.Seq; @@ -103,6 +104,11 @@ public Date getDate(int i) { throw new UnsupportedOperationException(); } + @Override + public Timestamp getTimestamp(int i) { + throw new UnsupportedOperationException(); + } + @Override public Seq getSeq(int i) { throw new UnsupportedOperationException(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 0d460b634d9b0..8aaf5d7d89154 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -260,9 +260,15 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - // TODO(davies): This is not the right default implementation, we use Int as Date internally def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date] + /** + * Returns the value at position i of date type as java.sql.Timestamp. + * + * @throws ClassCastException when data type does not match. + */ + def getTimestamp(i: Int): java.sql.Timestamp = apply(i).asInstanceOf[java.sql.Timestamp] + /** * Returns the value at position i of array type as a Scala Seq. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 2e7b4c236d8f8..beb82dbc08642 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} import java.math.{BigDecimal => JavaBigDecimal} -import java.sql.Date +import java.sql.{Timestamp, Date} import java.util.{Map => JavaMap} import javax.annotation.Nullable @@ -58,6 +58,7 @@ object CatalystTypeConverters { case structType: StructType => StructConverter(structType) case StringType => StringConverter case DateType => DateConverter + case TimestampType => TimestampConverter case dt: DecimalType => BigDecimalConverter case BooleanType => BooleanConverter case ByteType => ByteConverter @@ -274,6 +275,15 @@ object CatalystTypeConverters { override def toScalaImpl(row: Row, column: Int): Date = toScala(row.getInt(column)) } + private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] { + override def toCatalystImpl(scalaValue: Timestamp): Long = + DateUtils.fromJavaTimestamp(scalaValue) + override def toScala(catalystValue: Any): Timestamp = + if (catalystValue == null) null + else DateUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long]) + override def toScalaImpl(row: Row, column: Int): Timestamp = toScala(row.getLong(column)) + } + private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { case d: BigDecimal => Decimal(d) @@ -367,6 +377,7 @@ object CatalystTypeConverters { def convertToCatalyst(a: Any): Any = a match { case s: String => StringConverter.toCatalyst(s) case d: Date => DateConverter.toCatalyst(d) + case t: Timestamp => TimestampConverter.toCatalyst(t) case d: BigDecimal => BigDecimalConverter.toCatalyst(d) case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) case seq: Seq[Any] => seq.map(convertToCatalyst) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 18102d1acb5b3..8d93957fea2fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -113,7 +113,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToString(from: DataType): Any => Any = from match { case BinaryType => buildCast[Array[Byte]](_, UTF8String(_)) case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d))) - case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t))) + case TimestampType => buildCast[Long](_, + t => UTF8String(timestampToString(DateUtils.toJavaTimestamp(t)))) case _ => buildCast[Any](_, o => UTF8String(o.toString)) } @@ -127,7 +128,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case StringType => buildCast[UTF8String](_, _.length() != 0) case TimestampType => - buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) + buildCast[Long](_, t => t != 0) case DateType => // Hive would return null when cast from date to boolean 
buildCast[Int](_, d => null) @@ -158,20 +159,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if (periodIdx != -1 && n.length() - periodIdx > 9) { n = n.substring(0, periodIdx + 10) } - try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null } + try DateUtils.fromJavaTimestamp(Timestamp.valueOf(n)) + catch { case _: java.lang.IllegalArgumentException => null } }) case BooleanType => - buildCast[Boolean](_, b => new Timestamp(if (b) 1 else 0)) + buildCast[Boolean](_, b => (if (b) 1L else 0)) case LongType => - buildCast[Long](_, l => new Timestamp(l)) + buildCast[Long](_, l => longToTimestamp(l)) case IntegerType => - buildCast[Int](_, i => new Timestamp(i)) + buildCast[Int](_, i => longToTimestamp(i.toLong)) case ShortType => - buildCast[Short](_, s => new Timestamp(s)) + buildCast[Short](_, s => longToTimestamp(s.toLong)) case ByteType => - buildCast[Byte](_, b => new Timestamp(b)) + buildCast[Byte](_, b => longToTimestamp(b.toLong)) case DateType => - buildCast[Int](_, d => new Timestamp(DateUtils.toJavaDate(d).getTime)) + buildCast[Int](_, d => DateUtils.toMillisSinceEpoch(d) * 10000) // TimestampWritable.decimalToTimestamp case DecimalType() => buildCast[Decimal](_, d => decimalToTimestamp(d)) @@ -191,25 +193,17 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w }) } - private[this] def decimalToTimestamp(d: Decimal) = { - val seconds = Math.floor(d.toDouble).toLong - val bd = (d.toBigDecimal - seconds) * 1000000000 - val nanos = bd.intValue() - - val millis = seconds * 1000 - val t = new Timestamp(millis) - - // remaining fractional portion as nanos - t.setNanos(nanos) - t + private[this] def decimalToTimestamp(d: Decimal): Long = { + (d.toBigDecimal * 10000000L).longValue() } - // Timestamp to long, converting milliseconds to seconds - private[this] def timestampToLong(ts: Timestamp) = Math.floor(ts.getTime / 1000.0).toLong - - private[this] def timestampToDouble(ts: Timestamp) = { - // First part is the seconds since the beginning of time, followed by nanosecs. - Math.floor(ts.getTime / 1000.0).toLong + ts.getNanos.toDouble / 1000000000 + // converting milliseconds to 100ns + private[this] def longToTimestamp(t: Long): Long = t * 10000L + // converting 100ns to seconds + private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 10000000L).toLong + // converting 100ns to seconds in double + private[this] def timestampToDouble(ts: Long): Double = { + ts / 10000000.0 } // Converts Timestamp to string according to Hive TimestampWritable convention @@ -234,7 +228,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case TimestampType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. 
- buildCast[Timestamp](_, t => DateUtils.millisToDays(t.getTime)) + buildCast[Long](_, t => DateUtils.millisToDays(t / 10000L)) // Hive throws this exception as a Semantic Exception // It is never possible to compare result when hive return with exception, // so we can return null @@ -253,7 +247,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t)) + buildCast[Long](_, t => timestampToLong(t)) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } @@ -269,7 +263,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t).toInt) + buildCast[Long](_, t => timestampToLong(t).toInt) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } @@ -285,7 +279,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t).toShort) + buildCast[Long](_, t => timestampToLong(t).toShort) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } @@ -301,7 +295,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t).toByte) + buildCast[Long](_, t => timestampToLong(t).toByte) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } @@ -334,7 +328,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Int](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. 
- buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) + buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) case DecimalType() => b => changePrecision(b.asInstanceOf[Decimal].clone(), target) case LongType => @@ -358,7 +352,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToDouble(t)) + buildCast[Long](_, t => timestampToDouble(t)) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) } @@ -374,7 +368,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) + buildCast[Long](_, t => timestampToDouble(t).toFloat) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index aa4099e4d7bf9..2c884517d62a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -203,6 +203,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR case BooleanType => new MutableBoolean case LongType => new MutableLong case DateType => new MutableInt // We use INT for DATE internally + case TimestampType => new MutableLong // We use Long for Timestamp internally case _ => new MutableAny }.toArray) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e95682f952a7b..80aa8fa056146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -122,7 +122,7 @@ class CodeGenContext { case BinaryType => "byte[]" case StringType => stringType case DateType => "int" - case TimestampType => "java.sql.Timestamp" + case TimestampType => "long" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" @@ -140,6 +140,7 @@ class CodeGenContext { case FloatType => "Float" case BooleanType => "Boolean" case DateType => "Integer" + case TimestampType => "Long" case _ => javaType(dt) } @@ -155,6 +156,7 @@ class CodeGenContext { case DoubleType => "-1.0" case IntegerType => "-1" case DateType => "-1" + case TimestampType => "-1L" case _ => "null" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 7caf4aaab88bb..274429cd1c55f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -73,7 +73,9 @@ object GenerateProjection extends 
CodeGenerator[Seq[Expression], Projection] { val specificAccessorFunctions = ctx.nativeTypes.map { dataType => val cases = expressions.zipWithIndex.map { - case (e, i) if e.dataType == dataType => + case (e, i) if e.dataType == dataType + || dataType == IntegerType && e.dataType == DateType + || dataType == LongType && e.dataType == TimestampType => s"case $i: return c$i;" case _ => "" }.mkString("\n ") @@ -96,7 +98,9 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val specificMutatorFunctions = ctx.nativeTypes.map { dataType => val cases = expressions.zipWithIndex.map { - case (e, i) if e.dataType == dataType => + case (e, i) if e.dataType == dataType + || dataType == IntegerType && e.dataType == DateType + || dataType == LongType && e.dataType == TimestampType => s"case $i: { c$i = value; return; }" case _ => "" }.mkString("\n") @@ -119,7 +123,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val nonNull = e.dataType match { case BooleanType => s"$col ? 0 : 1" case ByteType | ShortType | IntegerType | DateType => s"$col" - case LongType => s"$col ^ ($col >>> 32)" + case LongType | TimestampType => s"$col ^ ($col >>> 32)" case FloatType => s"Float.floatToIntBits($col)" case DoubleType => s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 297b35b4da94c..833c08a293dcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -37,7 +37,7 @@ object Literal { case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: Decimal => Literal(d, DecimalType.Unlimited) - case t: Timestamp => Literal(t, TimestampType) + case t: Timestamp => Literal(DateUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) @@ -100,7 +100,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres ev.isNull = "false" ev.primitive = value.toString "" - case FloatType => // This must go before NumericType + case FloatType => val v = value.asInstanceOf[Float] if (v.isNaN || v.isInfinite) { super.genCode(ctx, ev) @@ -109,7 +109,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres ev.primitive = s"${value}f" "" } - case DoubleType => // This must go before NumericType + case DoubleType => val v = value.asInstanceOf[Double] if (v.isNaN || v.isInfinite) { super.genCode(ctx, ev) @@ -118,15 +118,18 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres ev.primitive = s"${value}" "" } - - case ByteType | ShortType => // This must go before NumericType + case ByteType | ShortType => ev.isNull = "false" ev.primitive = s"(${ctx.javaType(dataType)})$value" "" - case dt: NumericType if !dt.isInstanceOf[DecimalType] => + case IntegerType | DateType => ev.isNull = "false" ev.primitive = value.toString "" + case TimestampType | LongType => + ev.isNull = "false" + ev.primitive = s"${value}L" + "" // eval() version may be faster for non-primitive types case other => super.genCode(ctx, ev) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3cbdfdfb13847..2c49352874fc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -254,9 +254,9 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { (c1, c3) => s"$c1 $symbol $c3" }) - case TimestampType => - // java.sql.Timestamp does not have compare() - super.genCode(ctx, ev) + case DateType | TimestampType => defineCodeGen (ctx, ev, { + (c1, c3) => s"$c1 $symbol $c3" + }) case other => defineCodeGen (ctx, ev, { (c1, c2) => s"$c1.compare($c2) $symbol 0" }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala index ad649acf536f9..5cadc141af1df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import java.sql.Date +import java.sql.{Timestamp, Date} import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast */ object DateUtils { private val MILLIS_PER_DAY = 86400000 + private val HUNDRED_NANOS_PER_SECOND = 10000000L // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. private val LOCAL_TIMEZONE = new ThreadLocal[TimeZone] { @@ -45,17 +46,17 @@ object DateUtils { ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt } - private def toMillisSinceEpoch(days: Int): Long = { + def toMillisSinceEpoch(days: Int): Long = { val millisUtc = days.toLong * MILLIS_PER_DAY millisUtc - LOCAL_TIMEZONE.get().getOffset(millisUtc) } - def fromJavaDate(date: java.sql.Date): Int = { + def fromJavaDate(date: Date): Int = { javaDateToDays(date) } - def toJavaDate(daysSinceEpoch: Int): java.sql.Date = { - new java.sql.Date(toMillisSinceEpoch(daysSinceEpoch)) + def toJavaDate(daysSinceEpoch: Int): Date = { + new Date(toMillisSinceEpoch(daysSinceEpoch)) } def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days)) @@ -64,9 +65,9 @@ object DateUtils { if (!s.contains('T')) { // JDBC escape string if (s.contains(' ')) { - java.sql.Timestamp.valueOf(s) + Timestamp.valueOf(s) } else { - java.sql.Date.valueOf(s) + Date.valueOf(s) } } else if (s.endsWith("Z")) { // this is zero timezone of ISO8601 @@ -87,4 +88,33 @@ object DateUtils { ISO8601GMT.parse(s) } } + + /** + * Return a java.sql.Timestamp from number of 100ns since epoch + */ + def toJavaTimestamp(num100ns: Long): Timestamp = { + // setNanos() will overwrite the millisecond part, so the milliseconds should be + // cut off at seconds + var seconds = num100ns / HUNDRED_NANOS_PER_SECOND + var nanos = num100ns % HUNDRED_NANOS_PER_SECOND + // setNanos() can not accept negative value + if (nanos < 0) { + nanos += HUNDRED_NANOS_PER_SECOND + seconds -= 1 + } + val t = new Timestamp(seconds * 1000) + t.setNanos(nanos.toInt * 100) + t + } + + /** + * Return the number of 100ns since epoch from java.sql.Timestamp. 
+ */ + def fromJavaTimestamp(t: Timestamp): Long = { + if (t != null) { + t.getTime() * 10000L + (t.getNanos().toLong / 100) % 10000L + } else { + 0L + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala index aebabfc475925..a558641fcfed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import java.sql.Timestamp - import scala.math.Ordering import scala.reflect.runtime.universe.typeTag @@ -38,18 +36,16 @@ class TimestampType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Timestamp + private[sql] type InternalType = Long @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = new Ordering[InternalType] { - def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) - } + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the TimestampType is 12 bytes. */ - override def defaultSize: Int = 12 + override def defaultSize: Int = 8 private[spark] override def asNullable: TimestampType = this } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 5bc7c30eee1b6..3aca94db3bd8f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Timestamp, Date} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ /** @@ -137,7 +138,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(cast(sd, DateType), StringType), sd) checkEvaluation(cast(cast(d, StringType), DateType), 0) checkEvaluation(cast(cast(nts, TimestampType), StringType), nts) - checkEvaluation(cast(cast(ts, StringType), TimestampType), ts) + checkEvaluation(cast(cast(ts, StringType), TimestampType), DateUtils.fromJavaTimestamp(ts)) // all convert to string type to check checkEvaluation(cast(cast(cast(nts, TimestampType), DateType), StringType), sd) @@ -269,9 +270,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(ts, LongType), 15.toLong) checkEvaluation(cast(ts, FloatType), 15.002f) checkEvaluation(cast(ts, DoubleType), 15.002) - checkEvaluation(cast(cast(tss, ShortType), TimestampType), ts) - checkEvaluation(cast(cast(tss, IntegerType), TimestampType), ts) - checkEvaluation(cast(cast(tss, LongType), TimestampType), ts) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, IntegerType), TimestampType), DateUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, LongType), TimestampType), DateUtils.fromJavaTimestamp(ts)) checkEvaluation( 
cast(cast(millis.toFloat / 1000, TimestampType), FloatType), millis.toFloat / 1000) @@ -283,7 +284,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { Decimal(1)) // A test for higher precision than millis - checkEvaluation(cast(cast(0.00000001, TimestampType), DoubleType), 0.00000001) + checkEvaluation(cast(cast(0.0000001, TimestampType), DoubleType), 0.0000001) checkEvaluation(cast(Double.NaN, TimestampType), null) checkEvaluation(cast(1.0 / 0.0, TimestampType), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateUtilsSuite.scala new file mode 100644 index 0000000000000..a4245545ffc1d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateUtilsSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.sql.Timestamp + +import org.apache.spark.SparkFunSuite + + +class DateUtilsSuite extends SparkFunSuite { + + test("timestamp") { + val now = new Timestamp(System.currentTimeMillis()) + now.setNanos(100) + val ns = DateUtils.fromJavaTimestamp(now) + assert(ns % 10000000L == 1) + assert(DateUtils.toJavaTimestamp(ns) == now) + + List(-111111111111L, -1L, 0, 1L, 111111111111L).foreach { t => + val ts = DateUtils.toJavaTimestamp(t) + assert(DateUtils.fromJavaTimestamp(ts) == t) + assert(DateUtils.toJavaTimestamp(DateUtils.fromJavaTimestamp(ts)) == ts) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 261c4fcad24aa..077c0ad70ac4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -190,7 +190,7 @@ class DataTypeSuite extends SparkFunSuite { checkDefaultSize(DecimalType(10, 5), 4096) checkDefaultSize(DecimalType.Unlimited, 4096) checkDefaultSize(DateType, 4) - checkDefaultSize(TimestampType, 12) + checkDefaultSize(TimestampType, 8) checkDefaultSize(StringType, 4096) checkDefaultSize(BinaryType, 4096) checkDefaultSize(ArrayType(DoubleType, true), 800) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index b0f983c180673..83881a3687090 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.columnar -import java.sql.Timestamp - import org.apache.spark.sql.Row -import 
org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { @@ -234,22 +232,7 @@ private[sql] class StringColumnStats extends ColumnStats { private[sql] class DateColumnStats extends IntColumnStats -private[sql] class TimestampColumnStats extends ColumnStats { - protected var upper: Timestamp = null - protected var lower: Timestamp = null - - override def gatherStats(row: Row, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) - if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Timestamp] - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += TIMESTAMP.defaultSize - } - } - - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) -} +private[sql] class TimestampColumnStats extends LongColumnStats private[sql] class BinaryColumnStats extends ColumnStats { override def gatherStats(row: Row, ordinal: Int): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 20be5ca9d0046..c9c4d630fb5f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag @@ -355,22 +354,20 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { } } -private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { - override def extract(buffer: ByteBuffer): Timestamp = { - val timestamp = new Timestamp(buffer.getLong()) - timestamp.setNanos(buffer.getInt()) - timestamp +private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) { + override def extract(buffer: ByteBuffer): Long = { + buffer.getLong } - override def append(v: Timestamp, buffer: ByteBuffer): Unit = { - buffer.putLong(v.getTime).putInt(v.getNanos) + override def append(v: Long, buffer: ByteBuffer): Unit = { + buffer.putLong(v) } - override def getField(row: Row, ordinal: Int): Timestamp = { - row(ordinal).asInstanceOf[Timestamp] + override def getField(row: Row, ordinal: Int): Long = { + row(ordinal).asInstanceOf[Long] } - override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = { + override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { row(ordinal) = value } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 256d527d7b636..60f3b2d539ffe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -20,14 +20,13 @@ package org.apache.spark.sql.execution import java.io._ import java.math.{BigDecimal, BigInteger} import java.nio.ByteBuffer -import java.sql.Timestamp import scala.reflect.ClassTag -import org.apache.spark.serializer._ import org.apache.spark.Logging +import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import 
org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, MutableRow, SpecificMutableRow} import org.apache.spark.sql.types._ /** @@ -304,11 +303,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val timestamp = row.getAs[java.sql.Timestamp](i) - val time = timestamp.getTime - val nanos = timestamp.getNanos - out.writeLong(time - (nanos / 1000000)) // Write the milliseconds value. - out.writeInt(nanos) // Write the nanoseconds part. + out.writeLong(row.getAs[Long](i)) } case StringType => @@ -429,11 +424,7 @@ private[sql] object SparkSqlSerializer2 { if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { - val time = in.readLong() // Read the milliseconds value. - val nanos = in.readInt() // Read the nanoseconds part. - val timestamp = new Timestamp(time) - timestamp.setNanos(nanos) - mutableRow.update(i, timestamp) + mutableRow.update(i, in.readLong()) } case StringType => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index dffb265601bdb..720b529d5946f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -170,6 +170,8 @@ package object debug { case (_: Short, ShortType) => case (_: Boolean, BooleanType) => case (_: Double, DoubleType) => + case (_: Int, DateType) => + case (_: Long, TimestampType) => case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType) case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 342587904789a..955b478a4882f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -148,6 +148,7 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) case (date: Int, DateType) => DateUtils.toJavaDate(date) + case (t: Long, TimestampType) => DateUtils.toJavaTimestamp(t) case (s: UTF8String, StringType) => s.toString // Pyrolite can handle Timestamp and Decimal @@ -186,10 +187,12 @@ object EvaluatePython { }): Row case (c: java.util.Calendar, DateType) => - DateUtils.fromJavaDate(new java.sql.Date(c.getTime().getTime())) + DateUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis)) case (c: java.util.Calendar, TimestampType) => - new java.sql.Timestamp(c.getTime().getTime()) + c.getTimeInMillis * 10000L + case (t: java.sql.Timestamp, TimestampType) => + DateUtils.fromJavaTimestamp(t) case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index db68b9c86db1b..9028d5ed72c92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -385,7 +385,7 @@ private[sql] class JDBCRDD( // DateUtils.fromJavaDate does not handle null value, so we need to check it. 
val dateVal = rs.getDate(pos) if (dateVal != null) { - mutableRow.update(i, DateUtils.fromJavaDate(dateVal)) + mutableRow.setInt(i, DateUtils.fromJavaDate(dateVal)) } else { mutableRow.update(i, null) } @@ -417,7 +417,13 @@ private[sql] class JDBCRDD( case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 case StringConversion => mutableRow.setString(i, rs.getString(pos)) - case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) + case TimestampConversion => + val t = rs.getTimestamp(pos) + if (t != null) { + mutableRow.setLong(i, DateUtils.fromJavaTimestamp(t)) + } else { + mutableRow.update(i, null) + } case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) case BinaryLongConversion => { val bytes = rs.getBytes(pos) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala index 0e223758051a6..4e07cf36ae434 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.json import java.io.ByteArrayOutputStream -import java.sql.Timestamp import scala.collection.Map @@ -65,10 +64,10 @@ private[sql] object JacksonParser { DateUtils.millisToDays(DateUtils.stringToTime(parser.getText).getTime) case (VALUE_STRING, TimestampType) => - new Timestamp(DateUtils.stringToTime(parser.getText).getTime) + DateUtils.stringToTime(parser.getText).getTime * 10000L case (VALUE_NUMBER_INT, TimestampType) => - new Timestamp(parser.getLongValue) + parser.getLongValue * 10000L case (_, StringType) => val writer = new ByteArrayOutputStream() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 7e1e21f5fbb99..fb0d137bdbfdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.json -import java.sql.Timestamp - import scala.collection.Map import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} @@ -398,11 +396,11 @@ private[sql] object JsonRDD extends Logging { } } - private def toTimestamp(value: Any): Timestamp = { + private def toTimestamp(value: Any): Long = { value match { - case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) - case value: java.lang.Long => new Timestamp(value) - case value: java.lang.String => toTimestamp(DateUtils.stringToTime(value).getTime) + case value: java.lang.Integer => value.asInstanceOf[Int].toLong * 10000L + case value: java.lang.Long => value * 10000L + case value: java.lang.String => DateUtils.stringToTime(value).getTime * 10000L } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 85c2ce740fe52..ddc5097f88fb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -28,6 +28,7 @@ import org.apache.parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Co import org.apache.parquet.schema.MessageType import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateUtils import 
org.apache.spark.sql.parquet.CatalystConverter.FieldType import org.apache.spark.sql.types._ import org.apache.spark.sql.parquet.timestamp.NanoTime @@ -266,8 +267,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { /** * Read a Timestamp value from a Parquet Int96Value */ - protected[parquet] def readTimestamp(value: Binary): Timestamp = { - CatalystTimestampConverter.convertToTimestamp(value) + protected[parquet] def readTimestamp(value: Binary): Long = { + DateUtils.fromJavaTimestamp(CatalystTimestampConverter.convertToTimestamp(value)) } } @@ -401,7 +402,7 @@ private[parquet] class CatalystPrimitiveRowConverter( current.setInt(fieldIndex, value) override protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = - current.update(fieldIndex, value) + current.setInt(fieldIndex, value) override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = current.setLong(fieldIndex, value) @@ -425,7 +426,7 @@ private[parquet] class CatalystPrimitiveRowConverter( current.update(fieldIndex, UTF8String(value)) override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = - current.update(fieldIndex, readTimestamp(value)) + current.setLong(fieldIndex, readTimestamp(value)) override protected[parquet] def updateDecimal( fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 89db408b1c382..e03dbdec0491d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -29,6 +29,7 @@ import org.apache.parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ /** @@ -204,7 +205,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case IntegerType => writer.addInteger(value.asInstanceOf[Int]) case ShortType => writer.addInteger(value.asInstanceOf[Short]) case LongType => writer.addLong(value.asInstanceOf[Long]) - case TimestampType => writeTimestamp(value.asInstanceOf[java.sql.Timestamp]) + case TimestampType => writeTimestamp(value.asInstanceOf[Long]) case ByteType => writer.addInteger(value.asInstanceOf[Byte]) case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) @@ -311,8 +312,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) } - private[parquet] def writeTimestamp(ts: java.sql.Timestamp): Unit = { - val binaryNanoTime = CatalystTimestampConverter.convertFromTimestamp(ts) + private[parquet] def writeTimestamp(ts: Long): Unit = { + val binaryNanoTime = CatalystTimestampConverter.convertFromTimestamp( + DateUtils.toJavaTimestamp(ts)) writer.addBinary(binaryNanoTime) } } @@ -357,7 +359,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case FloatType => writer.addFloat(record.getFloat(index)) case BooleanType => writer.addBoolean(record.getBoolean(index)) case DateType => writer.addInteger(record.getInt(index)) - case TimestampType => writeTimestamp(record(index).asInstanceOf[java.sql.Timestamp]) + case TimestampType => 
writeTimestamp(record(index).asInstanceOf[Long]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 72e60d9aa75cb..17a3cec48b856 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ -import org.apache.spark.storage.{RDDBlockId, StorageLevel} +import org.apache.spark.storage.{StorageLevel, RDDBlockId} case class BigData(s: String) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 339e719f39f16..16836628cb73a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -31,7 +31,7 @@ class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(15, 10), Row(null, null, 0)) testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) + testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(Long.MaxValue, Long.MinValue, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index a1e76eaa982cc..8421e670ff05d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -18,17 +18,16 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp -import com.esotericsoftware.kryo.{Serializer, Kryo} import com.esotericsoftware.kryo.io.{Input, Output} -import org.apache.spark.serializer.KryoRegistrator +import com.esotericsoftware.kryo.{Kryo, Serializer} -import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} class ColumnTypeSuite extends SparkFunSuite with Logging { val DEFAULT_BUFFER_SIZE = 512 @@ -36,7 +35,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { test("defaultSize") { val checks = Map( INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, - FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 12, + FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 8, BINARY -> 16, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => @@ -69,7 +68,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { checkActualSize(BOOLEAN, true, 1) 
checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length) checkActualSize(DATE, 0, 4) - checkActualSize(TIMESTAMP, new Timestamp(0L), 12) + checkActualSize(TIMESTAMP, 0L, 8) val binary = Array.fill[Byte](4)(0: Byte) checkActualSize(BINARY, binary, 4 + 4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 75d993e563e06..c5d38595c0bec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.columnar -import java.sql.Timestamp - import scala.collection.immutable.HashSet import scala.util.Random import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, AtomicType} +import org.apache.spark.sql.types.{AtomicType, DataType, Decimal, UTF8String} object ColumnarTestUtils { def makeNullRow(length: Int): GenericMutableRow = { @@ -52,10 +50,7 @@ object ColumnarTestUtils { case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) case DATE => Random.nextInt() - case TIMESTAMP => - val timestamp = new Timestamp(Random.nextLong()) - timestamp.setNanos(Random.nextInt(999999999)) - timestamp + case TIMESTAMP => Random.nextLong() case _ => // Using a random one-element map instead of an arbitrary object Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 49d348c3ed21b..69ab1c292d221 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -326,7 +326,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(cal.get(Calendar.HOUR) === 11) assert(cal.get(Calendar.MINUTE) === 22) assert(cal.get(Calendar.SECOND) === 33) - assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543543) + assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543500) } test("test DATE types") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index d889c7be17ce7..fca24364fe6ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -76,21 +76,25 @@ class JsonSuite extends QueryTest with TestJsonData { checkTypePromotion( Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited)) - checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType)) - checkTypePromotion(new Timestamp(intNumber.toLong), + checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(intNumber)), + enforceCorrectType(intNumber, TimestampType)) + checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)), enforceCorrectType(intNumber.toLong, TimestampType)) val strTime = "2014-09-30 12:34:56" - checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType)) + checkTypePromotion(DateUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), + enforceCorrectType(strTime, TimestampType)) val strDate = "2014-10-15" checkTypePromotion( 
DateUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) val ISO8601Time1 = "1970-01-01T01:00:01.0Z" - checkTypePromotion(new Timestamp(3601000), enforceCorrectType(ISO8601Time1, TimestampType)) + checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(3601000)), + enforceCorrectType(ISO8601Time1, TimestampType)) checkTypePromotion(DateUtils.millisToDays(3601000), enforceCorrectType(ISO8601Time1, DateType)) val ISO8601Time2 = "1970-01-01T02:00:01-01:00" - checkTypePromotion(new Timestamp(10801000), enforceCorrectType(ISO8601Time2, TimestampType)) + checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(10801000)), + enforceCorrectType(ISO8601Time2, TimestampType)) checkTypePromotion(DateUtils.millisToDays(10801000), enforceCorrectType(ISO8601Time2, DateType)) } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 0693c7ea5b332..82c0b494598a8 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -252,7 +252,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "load_dyn_part14.*", // These work alone but fail when run with other tests... // the answer is sensitive for jdk version - "udf_java_method" + "udf_java_method", + + // Spark SQL use Long for TimestampType, lose the precision under 100ns + "timestamp_1", + "timestamp_2" ) /** @@ -795,8 +799,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "stats_publisher_error_1", "subq2", "tablename_with_select", - "timestamp_1", - "timestamp_2", "timestamp_3", "timestamp_comparison", "timestamp_lazy", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index c466203cd0220..1f14cba78f479 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -250,7 +250,8 @@ private[hive] trait HiveInspectors { PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, poi.getWritableConstantValue.getHiveDecimal) case poi: WritableConstantTimestampObjectInspector => - poi.getWritableConstantValue.getTimestamp.clone() + val t = poi.getWritableConstantValue + t.getSeconds * 10000000L + t.getNanos / 100L case poi: WritableConstantIntObjectInspector => poi.getWritableConstantValue.get() case poi: WritableConstantDoubleObjectInspector => @@ -313,11 +314,11 @@ private[hive] trait HiveInspectors { case x: DateObjectInspector if x.preferWritable() => DateUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get()) case x: DateObjectInspector => DateUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) - // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object - // if next timestamp is null, so Timestamp object is cloned case x: TimestampObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).getTimestamp.clone() - case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone() + val t = x.getPrimitiveWritableObject(data) + t.getSeconds * 10000000L + t.getNanos / 100 + case ti: TimestampObjectInspector => + 
DateUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data)) case _ => pi.getPrimitiveJavaObject(data) } case li: ListObjectInspector => @@ -356,6 +357,9 @@ private[hive] trait HiveInspectors { case _: JavaDateObjectInspector => (o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int]) + case _: JavaTimestampObjectInspector => + (o: Any) => DateUtils.toJavaTimestamp(o.asInstanceOf[Long]) + case soi: StandardStructObjectInspector => val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) (o: Any) => { @@ -465,7 +469,7 @@ private[hive] trait HiveInspectors { case _: DateObjectInspector if x.preferWritable() => getDateWritable(a) case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int]) case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a) - case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp] + case _: TimestampObjectInspector => DateUtils.toJavaTimestamp(a.asInstanceOf[Long]) } case x: SettableStructObjectInspector => val fieldRefs = x.getAllStructFieldRefs @@ -727,7 +731,7 @@ private[hive] trait HiveInspectors { TypeInfoFactory.voidTypeInfo, null) private def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].getBytes) private def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) @@ -776,7 +780,7 @@ private[hive] trait HiveInspectors { if (value == null) { null } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + new hiveIo.TimestampWritable(DateUtils.toJavaTimestamp(value.asInstanceOf[Long])) } private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 334bfccc9d200..d3c82d8c2e326 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -363,10 +363,10 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) case oi: TimestampObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.update(ordinal, oi.getPrimitiveJavaObject(value).clone()) + row.setLong(ordinal, DateUtils.fromJavaTimestamp(oi.getPrimitiveJavaObject(value))) case oi: DateObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.update(ordinal, DateUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) + row.setInt(ordinal, DateUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) case oi: BinaryObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, oi.getPrimitiveJavaObject(value)) diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 index 27de46fdf22ac..84a31a5a6970b 100644 --- a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 +++ b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 @@ -1 +1 @@ --0.0010000000000000009 +-0.001 From 6a47114bc297f0bce874e425feb1c24a5c26cef0 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Wed, 10 Jun 2015 18:19:12 
-0700 Subject: [PATCH 11/18] [SPARK-8285] [SQL] CombineSum should be calculated as unlimited decimal first case cs @ CombineSum(expr) => val calcType = expr.dataType expr.dataType match { case DecimalType.Fixed(_, _) => DecimalType.Unlimited case _ => expr.dataType } As written, calcType was always expr.dataType: the trailing match block is a separate expression whose result is discarded. All credit belongs to IntelliJ for spotting this. Author: navis.ryu Closes #6736 from navis/SPARK-8285 and squashes the following commits: 20382c1 [navis.ryu] [SPARK-8285] [SQL] CombineSum should be calculated as unlimited decimal first --- .../org/apache/spark/sql/execution/GeneratedAggregate.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 3e27c1bde2dfd..af3791734d0c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -118,7 +118,7 @@ case class GeneratedAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) case cs @ CombineSum(expr) => - val calcType = expr.dataType + val calcType = expr.dataType match { case DecimalType.Fixed(_, _) => DecimalType.Unlimited @@ -129,7 +129,7 @@ case class GeneratedAggregate( val currentSum = AttributeReference("currentSum", calcType, nullable = true)() val initialValue = Literal.create(null, calcType) - // Coalasce avoids double calculation... + // Coalesce avoids double calculation... // but really, common sub expression elimination would be better.... val zero = Cast(Literal(0), calcType) // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its From 4e42842e82e058d54329bd66185d8a7e77ab335a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Jun 2015 18:22:47 -0700 Subject: [PATCH 12/18] [SPARK-8164] transformExpressions should support nested expression sequence Currently we only support `Seq[Expression]`; we should handle cases like `Seq[Seq[Expression]]` so that we can remove the unnecessary `GroupExpression`. Author: Wenchen Fan Closes #6706 from cloud-fan/clean and squashes the following commits: 60a1193 [Wenchen Fan] support nested expression sequence and remove GroupExpression --- .../sql/catalyst/analysis/Analyzer.scala | 6 ++--- .../sql/catalyst/expressions/Expression.scala | 12 ---------- .../spark/sql/catalyst/plans/QueryPlan.scala | 22 +++++++++---------- .../plans/logical/basicOperators.scala | 2 +- .../sql/catalyst/trees/TreeNodeSuite.scala | 14 ++++++++++++ .../apache/spark/sql/execution/Expand.scala | 4 ++-- 6 files changed, 30 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c4f12cfe87993..cbd8def4f1d3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -172,8 +172,8 @@ class Analyzer( * expressions which equal GroupBy expressions with Literal(null), if those expressions * are not set for this grouping set (according to the bit mask).
*/ - private[this] def expand(g: GroupingSets): Seq[GroupExpression] = { - val result = new scala.collection.mutable.ArrayBuffer[GroupExpression] + private[this] def expand(g: GroupingSets): Seq[Seq[Expression]] = { + val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]] g.bitmasks.foreach { bitmask => // get the non selected grouping attributes according to the bit mask @@ -194,7 +194,7 @@ class Analyzer( Literal.create(bitmask, IntegerType) }) - result += GroupExpression(substitution) + result += substitution } result.toSeq diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a05794f1dbd86..63dd5f9854aed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -239,18 +239,6 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } -// TODO Semantically we probably not need GroupExpression -// All we need is holding the Seq[Expression], and ONLY used in doing the -// expressions transformation correctly. Probably will be removed since it's -// not like a real expressions. -case class GroupExpression(children: Seq[Expression]) extends Expression { - self: Product => - override def eval(input: Row): Any = throw new UnsupportedOperationException - override def nullable: Boolean = false - override def foldable: Boolean = false - override def dataType: DataType = throw new UnsupportedOperationException -} - /** * Expressions that require a specific `DataType` as input should implement this trait * so that the proper type conversions can be performed in the analyzer. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index eff5c61644944..2f545bb432165 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -81,17 +81,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } } - val newArgs = productIterator.map { + def recursiveTransform(arg: Any): AnyRef = arg match { case e: Expression => transformExpressionDown(e) case Some(e: Expression) => Some(transformExpressionDown(e)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs - case seq: Traversable[_] => seq.map { - case e: Expression => transformExpressionDown(e) - case other => other - } + case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other - }.toArray + } + + val newArgs = productIterator.map(recursiveTransform).toArray if (changed) makeCopy(newArgs) else this } @@ -114,17 +113,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } } - val newArgs = productIterator.map { + def recursiveTransform(arg: Any): AnyRef = arg match { case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs - case seq: Traversable[_] => seq.map { - case e: Expression => transformExpressionUp(e) - case other => other - } + case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other - }.toArray + } + + val newArgs = productIterator.map(recursiveTransform).toArray if (changed) makeCopy(newArgs) else this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e77e5c27b687a..963c7820914f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -226,7 +226,7 @@ case class Window( * @param child Child operator */ case class Expand( - projections: Seq[GroupExpression], + projections: Seq[Seq[Expression]], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { override def statistics: Statistics = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 67db3d5e6d751..8ec79c3d4d28d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -31,6 +31,11 @@ case class Dummy(optKey: Option[Expression]) extends Expression { override def eval(input: Row): Any = null.asInstanceOf[Any] } +case class ComplexPlan(exprs: Seq[Seq[Expression]]) + extends org.apache.spark.sql.catalyst.plans.logical.LeafNode { + override def output: Seq[Attribute] = Nil +} + class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } @@ -220,4 +225,13 @@ class TreeNodeSuite extends SparkFunSuite { assert(expected === actual) } } + + test("transformExpressions on nested expression sequence") { + val plan = 
ComplexPlan(Seq(Seq(Literal(1)), Seq(Literal(2)))) + val actual = plan.transformExpressions { + case Literal(value, _) => Literal(value.toString) + } + val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2")))) + assert(expected === actual) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index f16ca36909fab..4b601c11924b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, Partit */ @DeveloperApi case class Expand( - projections: Seq[GroupExpression], + projections: Seq[Seq[Expression]], output: Seq[Attribute], child: SparkPlan) extends UnaryNode { @@ -49,7 +49,7 @@ case class Expand( // workers via closure. However we can't assume the Projection // is serializable because of the code gen, so we have to // create the projections within each of the partition processing. - val groups = projections.map(ee => newProjection(ee.children, child.output)).toArray + val groups = projections.map(ee => newProjection(ee, child.output)).toArray new Iterator[Row] { private[this] var result: Row = _ From 9fe3adccef687c92ff1ac17d946af089c8e28d66 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 10 Jun 2015 19:55:10 -0700 Subject: [PATCH 13/18] [SPARK-8248][SQL] string function: length Author: Cheng Hao Closes #6724 from chenghao-intel/length and squashes the following commits: aaa3c31 [Cheng Hao] revert the additional change 97148a9 [Cheng Hao] remove the codegen testing temporally ae08003 [Cheng Hao] update the comments 1eb1fd1 [Cheng Hao] simplify the code as commented 3e92d32 [Cheng Hao] use the selectExpr in unit test intead of SQLQuery 3c729aa [Cheng Hao] fix bug for constant null value in codegen 3641f06 [Cheng Hao] keep the length() method for registered function 8e30171 [Cheng Hao] update the code as comment db604ae [Cheng Hao] Add code gen support 548d2ef [Cheng Hao] register the length() 09a0738 [Cheng Hao] add length support --- .../catalyst/analysis/FunctionRegistry.scala | 13 +++++++----- .../sql/catalyst/expressions/Expression.scala | 3 +++ .../expressions/stringOperations.scala | 21 +++++++++++++++++++ .../expressions/StringFunctionsSuite.scala | 12 +++++++++++ .../org/apache/spark/sql/functions.scala | 18 ++++++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 20 ++++++++++++++++++ 6 files changed, 82 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ba89a5c8d1372..39875d7f216b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -89,14 +89,10 @@ object FunctionRegistry { expression[CreateArray]("array"), expression[Coalesce]("coalesce"), expression[Explode]("explode"), - expression[Lower]("lower"), - expression[Substring]("substr"), - expression[Substring]("substring"), expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), expression[Sqrt]("sqrt"), - expression[Upper]("upper"), // Math functions expression[Acos]("acos"), @@ -132,7 +128,14 @@ object FunctionRegistry { expression[Last]("last"), 
expression[Max]("max"), expression[Min]("min"), - expression[Sum]("sum") + expression[Sum]("sum"), + + // string functions + expression[Lower]("lower"), + expression[StringLength]("length"), + expression[Substring]("substr"), + expression[Substring]("substring"), + expression[Upper]("upper") ) val builtin: FunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 63dd5f9854aed..8c1e4d74f9df1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -212,6 +212,9 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + /** * Called by unary expressions to generate a code block that returns null if its parent returns * null, and if not not null, use `f` to generate the expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 856f56488c7a5..345038323ddc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -294,3 +294,24 @@ object Substring { apply(str, pos, Literal(Integer.MAX_VALUE)) } } + +/** + * A function that return the length of the given string expression. 
+ */ +case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def dataType: DataType = IntegerType + override def expectedChildTypes: Seq[DataType] = Seq(StringType) + + override def eval(input: Row): Any = { + val string = child.eval(input) + if (string == null) null else string.asInstanceOf[UTF8String].length + } + + override def toString: String = s"length($child)" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).length()") + } +} + + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 2e81296c4e623..d363e631540d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -215,4 +215,16 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { evaluate("abbbbc" rlike regEx, create_row("**")) } } + + test("length for string") { + val regEx = 'a.string.at(0) + checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef")) + checkEvaluation(StringLength(regEx), 5, create_row("abdef")) + checkEvaluation(StringLength(regEx), 0, create_row("")) + checkEvaluation(StringLength(regEx), null, create_row(null)) + // TODO currently bug in codegen, let's temporally disable this + // checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) + } + + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b3fc1e6cd987e..083f6b6bceee8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -37,6 +37,7 @@ import org.apache.spark.util.Utils * @groupname normal_funcs Non-aggregate functions * @groupname math_funcs Math functions * @groupname window_funcs Window functions + * @groupname string_funcs String functions * @groupname Ungrouped Support functions for DataFrames. 
* @since 1.3.0 */ @@ -1317,6 +1318,23 @@ object functions { */ def toRadians(columnName: String): Column = toRadians(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// + // String functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Computes the length of a given string value + * @group string_funcs + * @since 1.5.0 + */ + def strlen(e: Column): Column = StringLength(e.expr) + + /** + * Computes the length of a given string column + * @group string_funcs + * @since 1.5.0 + */ + def strlen(columnName: String): Column = strlen(Column(columnName)) ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b93ad39f5da45..171a2151e67ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -109,4 +109,24 @@ class DataFrameFunctionsSuite extends QueryTest { testData2.select(bitwiseNOT($"a")), testData2.collect().toSeq.map(r => Row(~r.getInt(0)))) } + + test("length") { + checkAnswer( + nullStrings.select(strlen($"s"), strlen("s")), + nullStrings.collect().toSeq.map { r => + val v = r.getString(1) + val l = if (v == null) null else v.length + Row(l, l) + }) + + checkAnswer( + nullStrings.selectExpr("length(s)"), + nullStrings.collect().toSeq.map { r => + val v = r.getString(1) + val l = if (v == null) null else v.length + Row(l) + }) + } + + } From 2758ff0a96f03a61e10999b2462acf7a13236b7c Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 10 Jun 2015 20:22:32 -0700 Subject: [PATCH 14/18] [SPARK-8217] [SQL] math function log2 Author: Daoyuan Wang This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #6718 from adrian-wang/udflog2 and squashes the following commits: 3909f48 [Daoyuan Wang] math function: log2 --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 17 ++++++++++++++++ .../expressions/MathFunctionsSuite.scala | 6 ++++++ .../org/apache/spark/sql/functions.scala | 20 +++++++++++++++++-- .../spark/sql/DataFrameFunctionsSuite.scala | 12 +++++++++++ 5 files changed, 54 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 39875d7f216b2..a7816e327526f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -111,6 +111,7 @@ object FunctionRegistry { expression[Log10]("log10"), expression[Log1p]("log1p"), expression[Pi]("pi"), + expression[Log2]("log2"), expression[Pow]("pow"), expression[Rint]("rint"), expression[Signum]("signum"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index e1d8c9a0cdb5a..97e960b8d6422 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -161,6 +161,23 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") +case class Log2(child: Expression) + extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2); + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + } + """ + } +} + case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10") case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 1fe69059d39da..864c954ee82cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -185,6 +185,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true) } + test("log2") { + def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) + testUnary(Log2, f, (0 to 20).map(_ * 0.1)) + testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true) + } + test("pow") { testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 083f6b6bceee8..c5b77724aae17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1084,7 +1084,7 @@ object functions { def log(columnName: String): Column = log(Column(columnName)) /** - * Computes the logarithm of the given value in Base 10. + * Computes the logarithm of the given value in base 10. * * @group math_funcs * @since 1.4.0 @@ -1092,7 +1092,7 @@ object functions { def log10(e: Column): Column = Log10(e.expr) /** - * Computes the logarithm of the given value in Base 10. + * Computes the logarithm of the given value in base 10. * * @group math_funcs * @since 1.4.0 @@ -1124,6 +1124,22 @@ object functions { */ def pi(): Column = Pi() + /** + * Computes the logarithm of the given column in base 2. + * + * @group math_funcs + * @since 1.5.0 + */ + def log2(expr: Column): Column = Log2(expr.expr) + + /** + * Computes the logarithm of the given value in base 2. + * + * @group math_funcs + * @since 1.5.0 + */ + def log2(columnName: String): Column = log2(Column(columnName)) + /** * Returns the value of the first argument raised to the power of the second argument. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 171a2151e67ae..659b64c185f43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -128,5 +128,17 @@ class DataFrameFunctionsSuite extends QueryTest { }) } + test("log2 functions test") { + val df = Seq((1, 2)).toDF("a", "b") + checkAnswer( + df.select(log2("b") + log2("a")), + Row(1)) + checkAnswer( + ctx.sql("SELECT LOG2(8)"), + Row(3)) + checkAnswer( + ctx.sql("SELECT LOG2(null)"), + Row(null)) + } } From a777eb04bf981312b640326607158f78dd4163cd Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 10 Jun 2015 21:13:47 -0700 Subject: [PATCH 15/18] [HOTFIX] Adding more contributor name bindings --- dev/create-release/known_translations | 42 +++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index 0a599b5a65549..bbd4330e1c2e5 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -91,3 +91,45 @@ zapletal-martin - Martin Zapletal zuxqoj - Shekhar Bansal mingyukim - Mingyu Kim sigmoidanalytics - Mayur Rustagi +AiHe - Ai He +BenFradet - Ben Fradet +FavioVazquez - Favio Vazquez +JaysonSunshine - Jayson Sunshine +Liuchang0812 - Liu Chang +Sephiroth-Lin - Sephiroth Lin +baishuo - Cheng Lian +daisukebe - Shixiong Zhu +dobashim - Masaru Dobashi +ehnalis - Zoltan Zvara +emres - Emre Sevinc +gchen - Guancheng Chen +haiyangsea - Haiyang Sea +hlin09 - Hao Lin +hqzizania - Qian Huang +jeanlyn - Jean Lyn +jerluc - Jeremy A. Lucas +jrabary - Jaonary Rabarisoa +judynash - Judy Nash +kaka1992 - Chen Song +ksonj - Kalle Jepsen +kuromatsu-nobuyuki - Nobuyuki Kuromatsu +lazyman500 - Dong Xu +leahmcguire - Leah McGuire +mbittmann - Mark Bittmann +mbonaci - Marko Bonaci +meawoppl - Matthew Goodman +nyaapa - Arsenii Krasikov +phatak-dev - Madhukara Phatak +prabeesh - Prabeesh K +rakeshchalasani - Rakesh Chalasani +raschild - Marcelo Vanzin +rekhajoshm - Rekha Joshi +sisihj - June He +szheng79 - Shuai Zheng +ted-yu - Andrew Or +texasmichelle - Michelle Casbon +vinodkc - Vinod KC +yongtang - Yong Tang +ypcat - Pei-Lun Lee +zhichao-li - Zhichao Li +zzcclp - Zhichao Zhang From e84545fa771dde90de5675a9c551fe287af6f7fb Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 10 Jun 2015 22:56:36 -0700 Subject: [PATCH 16/18] [HOTFIX] Fixing errors in name mappings --- dev/create-release/known_translations | 4 ---- 1 file changed, 4 deletions(-) diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index bbd4330e1c2e5..5f2671a6e5053 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -97,8 +97,6 @@ FavioVazquez - Favio Vazquez JaysonSunshine - Jayson Sunshine Liuchang0812 - Liu Chang Sephiroth-Lin - Sephiroth Lin -baishuo - Cheng Lian -daisukebe - Shixiong Zhu dobashim - Masaru Dobashi ehnalis - Zoltan Zvara emres - Emre Sevinc @@ -122,11 +120,9 @@ nyaapa - Arsenii Krasikov phatak-dev - Madhukara Phatak prabeesh - Prabeesh K rakeshchalasani - Rakesh Chalasani -raschild - Marcelo Vanzin rekhajoshm - Rekha Joshi sisihj - June He szheng79 - Shuai Zheng -ted-yu - Andrew Or texasmichelle - Michelle Casbon vinodkc - Vinod KC yongtang - Yong Tang From 6b68366df345d4572cf138f9efe17e23d0d1971e Mon Sep 17 00:00:00 
2001 From: Adam Roberts Date: Thu, 11 Jun 2015 08:40:46 +0100 Subject: [PATCH 17/18] [SPARK-8289] Specify stack size for consistency with Java tests - resolves test failures This simple change specifies a stack size of 4096k for Java tests instead of the vendor default (defaults vary between Java vendors). It remedies test failures observed in JavaALSSuite with IBM and Oracle Java, whose default stack sizes are lower than OpenJDK's. 4096k is a suitable default with which the tests pass for every Java vendor tested. The alternative would be to reduce the number of iterations in the test (no failures were observed with 5 iterations instead of 15). -Xss works with Oracle's HotSpot VM, IBM's J9 VM and OpenJDK (IcedTea). I have ensured this change has no negative implications for other tests. Author: Adam Roberts Author: a-roberts Closes #6727 from a-roberts/IncJavaStackSize and squashes the following commits: ab40aea [Adam Roberts] Specify stack size for SBT builds 5032d8d [a-roberts] Update pom.xml --- pom.xml | 2 +- project/SparkBuild.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index e9700a5d7b149..6d4f717d4931b 100644 --- a/pom.xml +++ b/pom.xml @@ -1244,7 +1244,7 @@ **/*Suite.java ${project.build.directory}/surefire-reports - -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -Xmx3g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m