diff --git a/build/mvn b/build/mvn index a87c5a26230c8..53babf54debb6 100755 --- a/build/mvn +++ b/build/mvn @@ -34,14 +34,14 @@ install_app() { local binary="${_DIR}/$3" # setup `curl` and `wget` silent options if we're running on Jenkins - local curl_opts="" + local curl_opts="-L" local wget_opts="" if [ -n "$AMPLAB_JENKINS" ]; then - curl_opts="-s" - wget_opts="--quiet" + curl_opts="-s ${curl_opts}" + wget_opts="--quiet ${wget_opts}" else - curl_opts="--progress-bar" - wget_opts="--progress=bar:force" + curl_opts="--progress-bar ${curl_opts}" + wget_opts="--progress=bar:force ${wget_opts}" fi if [ -z "$3" -o ! -f "$binary" ]; then diff --git a/core/pom.xml b/core/pom.xml index 2dc5f747f2b71..4daaf88147142 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -132,6 +132,13 @@ jetty-servlet compile + + + org.eclipse.jetty.orbit + javax.servlet + ${orbit.version} + org.apache.commons diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a7adddb6c83ec..24490fddc5c6a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -25,29 +25,37 @@ import java.net.URI import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.AtomicInteger import java.util.UUID.randomUUID + import scala.collection.{Map, Set} import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} + +import akka.actor.Props + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} -import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, TextInputFormat} +import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, + FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} +import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, + TextInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} + import org.apache.mesos.MesosNativeLibrary -import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.executor.TriggerThreadDump -import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} +import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, + FixedLengthBinaryInputFormat} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, + SparkDeploySchedulerBackend, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ @@ -1016,12 +1024,48 @@ 
class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. */ - def addFile(path: String) { + def addFile(path: String): Unit = { + addFile(path, false) + } + + /** + * Add a file to be downloaded with this Spark job on every node. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(fileName)` to find its download location. + * + * A directory can be given if the recursive option is set to true. Currently directories are only + * supported for Hadoop-supported filesystems. + */ + def addFile(path: String, recursive: Boolean): Unit = { val uri = new URI(path) - val key = uri.getScheme match { - case null | "file" => env.httpFileServer.addFile(new File(uri.getPath)) - case "local" => "file:" + uri.getPath - case _ => path + val schemeCorrectedPath = uri.getScheme match { + case null | "local" => "file:" + uri.getPath + case _ => path + } + + val hadoopPath = new Path(schemeCorrectedPath) + val scheme = new URI(schemeCorrectedPath).getScheme + if (!Array("http", "https", "ftp").contains(scheme)) { + val fs = hadoopPath.getFileSystem(hadoopConfiguration) + if (!fs.exists(hadoopPath)) { + throw new FileNotFoundException(s"Added file $hadoopPath does not exist.") + } + val isDir = fs.isDirectory(hadoopPath) + if (!isLocal && scheme == "file" && isDir) { + throw new SparkException(s"addFile does not support local directories when not running " + + "local mode.") + } + if (!recursive && isDir) { + throw new SparkException(s"Added file $hadoopPath is a directory and recursive is not " + + "turned on.") + } + } + + val key = if (!isLocal && scheme == "file") { + env.httpFileServer.addFile(new File(uri.getPath)) + } else { + schemeCorrectedPath } val timestamp = System.currentTimeMillis addedFiles(key) = timestamp @@ -1633,8 +1677,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val schedulingMode = getSchedulingMode.toString val addedJarPaths = addedJars.keys.toSeq val addedFilePaths = addedFiles.keys.toSeq - val environmentDetails = - SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, addedFilePaths) + val environmentDetails = SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, + addedFilePaths) val environmentUpdate = SparkListenerEnvironmentUpdate(environmentDetails) listenerBus.post(environmentUpdate) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index d68854214ef06..03238e9fa0088 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,7 +21,7 @@ import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} @@ -191,6 +191,21 @@ class SparkHadoopUtil extends Logging { val method = context.getClass.getMethod("getConfiguration") method.invoke(context).asInstanceOf[Configuration] } + + /** + * Get 
[[FileStatus]] objects for all leaf children (files) under the given base path. If the + * given path points to a file, return a single-element collection containing [[FileStatus]] of + * that file. + */ + def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { + def recurse(path: Path) = { + val (directories, leaves) = fs.listStatus(path).partition(_.isDir) + leaves ++ directories.flatMap(f => listLeafStatuses(fs, f.getPath)) + } + + val baseStatus = fs.getFileStatus(basePath) + if (baseStatus.isDir) recurse(basePath) else Array(baseStatus) + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 3c8a0e40bf785..72d15e65bcde6 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -386,8 +386,10 @@ private[spark] object Utils extends Logging { } /** - * Download a file to target directory. Supports fetching the file in a variety of ways, - * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. + * Download a file or directory to target directory. Supports fetching the file in a variety of + * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based + * on the URL parameter. Fetching directories is only supported from Hadoop-compatible + * filesystems. * * If `useCache` is true, first attempts to fetch the file to a local cache that's shared * across executors running the same application. `useCache` is used mainly for @@ -456,7 +458,6 @@ private[spark] object Utils extends Logging { * * @param url URL that `sourceFile` originated from, for logging purposes. * @param in InputStream to download. - * @param tempFile File path to download `in` to. * @param destFile File path to move `tempFile` to. 
* @param fileOverwrite Whether to delete/overwrite an existing `destFile` that does not match * `sourceFile` @@ -464,9 +465,11 @@ private[spark] object Utils extends Logging { private def downloadFile( url: String, in: InputStream, - tempFile: File, destFile: File, fileOverwrite: Boolean): Unit = { + val tempFile = File.createTempFile("fetchFileTemp", null, + new File(destFile.getParentFile.getAbsolutePath)) + logInfo(s"Fetching $url to $tempFile") try { val out = new FileOutputStream(tempFile) @@ -505,7 +508,7 @@ private[spark] object Utils extends Logging { removeSourceFile: Boolean = false): Unit = { if (destFile.exists) { - if (!Files.equal(sourceFile, destFile)) { + if (!filesEqualRecursive(sourceFile, destFile)) { if (fileOverwrite) { logInfo( s"File $destFile exists and does not match contents of $url, replacing it with $url" @@ -540,13 +543,44 @@ private[spark] object Utils extends Logging { Files.move(sourceFile, destFile) } else { logInfo(s"Copying ${sourceFile.getAbsolutePath} to ${destFile.getAbsolutePath}") - Files.copy(sourceFile, destFile) + copyRecursive(sourceFile, destFile) + } + } + + private def filesEqualRecursive(file1: File, file2: File): Boolean = { + if (file1.isDirectory && file2.isDirectory) { + val subfiles1 = file1.listFiles() + val subfiles2 = file2.listFiles() + if (subfiles1.size != subfiles2.size) { + return false + } + subfiles1.sortBy(_.getName).zip(subfiles2.sortBy(_.getName)).forall { + case (f1, f2) => filesEqualRecursive(f1, f2) + } + } else if (file1.isFile && file2.isFile) { + Files.equal(file1, file2) + } else { + false + } + } + + private def copyRecursive(source: File, dest: File): Unit = { + if (source.isDirectory) { + if (!dest.mkdir()) { + throw new IOException(s"Failed to create directory ${dest.getPath}") + } + val subfiles = source.listFiles() + subfiles.foreach(f => copyRecursive(f, new File(dest, f.getName))) + } else { + Files.copy(source, dest) } } /** - * Download a file to target directory. Supports fetching the file in a variety of ways, - * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. + * Download a file or directory to target directory. Supports fetching the file in a variety of + * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based + * on the URL parameter. Fetching directories is only supported from Hadoop-compatible + * filesystems. * * Throws SparkException if the target file already exists and has different contents than * the requested file. @@ -558,14 +592,11 @@ private[spark] object Utils extends Logging { conf: SparkConf, securityMgr: SecurityManager, hadoopConf: Configuration) { - val tempFile = File.createTempFile("fetchFileTemp", null, new File(targetDir.getAbsolutePath)) val targetFile = new File(targetDir, filename) val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) Option(uri.getScheme).getOrElse("file") match { case "http" | "https" | "ftp" => - logInfo("Fetching " + url + " to " + tempFile) - var uc: URLConnection = null if (securityMgr.isAuthenticationEnabled()) { logDebug("fetchFile with security enabled") @@ -583,17 +614,44 @@ private[spark] object Utils extends Logging { uc.setReadTimeout(timeout) uc.connect() val in = uc.getInputStream() - downloadFile(url, in, tempFile, targetFile, fileOverwrite) + downloadFile(url, in, targetFile, fileOverwrite) case "file" => // In the case of a local file, copy the local file to the target directory. 
// Note the difference between uri vs url. val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) copyFile(url, sourceFile, targetFile, fileOverwrite) case _ => - // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others val fs = getHadoopFileSystem(uri, hadoopConf) - val in = fs.open(new Path(uri)) - downloadFile(url, in, tempFile, targetFile, fileOverwrite) + val path = new Path(uri) + fetchHcfsFile(path, new File(targetDir, path.getName), fs, conf, hadoopConf, fileOverwrite) + } + } + + /** + * Fetch a file or directory from a Hadoop-compatible filesystem. + * + * Visible for testing + */ + private[spark] def fetchHcfsFile( + path: Path, + targetDir: File, + fs: FileSystem, + conf: SparkConf, + hadoopConf: Configuration, + fileOverwrite: Boolean): Unit = { + if (!targetDir.mkdir()) { + throw new IOException(s"Failed to create directory ${targetDir.getPath}") + } + fs.listStatus(path).foreach { fileStatus => + val innerPath = fileStatus.getPath + if (fileStatus.isDir) { + fetchHcfsFile(innerPath, new File(targetDir, innerPath.getName), fs, conf, hadoopConf, + fileOverwrite) + } else { + val in = fs.open(innerPath) + val targetFile = new File(targetDir, innerPath.getName) + downloadFile(innerPath.toString, in, targetFile, fileOverwrite) + } } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 8b3c6871a7b39..50f347f1954de 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -17,10 +17,17 @@ package org.apache.spark +import java.io.File + +import com.google.common.base.Charsets._ +import com.google.common.io.Files + import org.scalatest.FunSuite import org.apache.hadoop.io.BytesWritable +import org.apache.spark.util.Utils + class SparkContextSuite extends FunSuite with LocalSparkContext { test("Only one SparkContext may be active at a time") { @@ -72,4 +79,74 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { val byteArray2 = converter.convert(bytesWritable) assert(byteArray2.length === 0) } + + test("addFile works") { + val file = File.createTempFile("someprefix", "somesuffix") + val absolutePath = file.getAbsolutePath + try { + Files.write("somewords", file, UTF_8) + val length = file.length() + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addFile(file.getAbsolutePath) + sc.parallelize(Array(1), 1).map(x => { + val gotten = new File(SparkFiles.get(file.getName)) + if (!gotten.exists()) { + throw new SparkException("file doesn't exist") + } + if (length != gotten.length()) { + throw new SparkException( + s"file has different length $length than added file ${gotten.length()}") + } + if (absolutePath == gotten.getAbsolutePath) { + throw new SparkException("file should have been copied") + } + x + }).count() + } finally { + sc.stop() + } + } + + test("addFile recursive works") { + val pluto = Utils.createTempDir() + val neptune = Utils.createTempDir(pluto.getAbsolutePath) + val saturn = Utils.createTempDir(neptune.getAbsolutePath) + val alien1 = File.createTempFile("alien", "1", neptune) + val alien2 = File.createTempFile("alien", "2", saturn) + + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addFile(neptune.getAbsolutePath, true) + sc.parallelize(Array(1), 1).map(x => { + val sep = File.separator + if (!new File(SparkFiles.get(neptune.getName + sep + 
alien1.getName)).exists()) { + throw new SparkException("can't access file under root added directory") + } + if (!new File(SparkFiles.get(neptune.getName + sep + saturn.getName + sep + alien2.getName)) + .exists()) { + throw new SparkException("can't access file in nested directory") + } + if (new File(SparkFiles.get(pluto.getName + sep + neptune.getName + sep + alien1.getName)) + .exists()) { + throw new SparkException("file exists that shouldn't") + } + x + }).count() + } finally { + sc.stop() + } + } + + test("addFile recursive can't add directories by default") { + val dir = Utils.createTempDir() + + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + intercept[SparkException] { + sc.addFile(dir.getAbsolutePath) + } + } finally { + sc.stop() + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 39e5d367d676c..2cc5817758cf7 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -371,7 +371,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) val timeout = AkkaUtils.lookupTimeout(conf) intercept[TimeoutException] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout * 2), timeout) } actorSystem.shutdown() diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4544382094f96..fe2b644251157 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -29,6 +29,9 @@ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.scalatest.FunSuite +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { @@ -381,4 +384,32 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { require(cnt === 2, "prepare should be called twice") require(time < 500, "preparation time should not count") } + + test("fetch hcfs dir") { + val tempDir = Utils.createTempDir() + val innerTempDir = Utils.createTempDir(tempDir.getPath) + val tempFile = File.createTempFile("someprefix", "somesuffix", innerTempDir) + val targetDir = new File("target-dir") + Files.write("some text", tempFile, UTF_8) + + try { + val path = new Path("file://" + tempDir.getAbsolutePath) + val conf = new Configuration() + val fs = Utils.getHadoopFileSystem(path.toString, conf) + Utils.fetchHcfsFile(path, targetDir, fs, new SparkConf(), conf, false) + assert(targetDir.exists()) + assert(targetDir.isDirectory()) + val newInnerDir = new File(targetDir, innerTempDir.getName) + println("inner temp dir: " + innerTempDir.getName) + targetDir.listFiles().map(_.getName).foreach(println) + assert(newInnerDir.exists()) + assert(newInnerDir.isDirectory()) + val newInnerFile = new File(newInnerDir, tempFile.getName) + assert(newInnerFile.exists()) + assert(newInnerFile.isFile()) + } finally { + Utils.deleteRecursively(tempDir) + Utils.deleteRecursively(targetDir) + } + } } diff --git a/dev/check-license b/dev/check-license index a006f65710d6d..39943f882b6ca 100755 --- 
a/dev/check-license +++ b/dev/check-license @@ -31,7 +31,7 @@ acquire_rat_jar () { printf "Attempting to fetch rat\n" JAR_DL="${JAR}.part" if [ $(command -v curl) ]; then - curl --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" + curl -L --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" elif [ $(command -v wget) ]; then wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" else diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 8841f7675d35e..efc4c612937df 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -7,7 +7,9 @@ {{ page.title }} - Spark {{site.SPARK_VERSION_SHORT}} Documentation - + {% if page.description %} + + {% endif %} {% if page.redirect %} diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index 7e55131754a3f..c2fe6b0e286ce 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -1,6 +1,7 @@ --- layout: global -title: Bagel Programming Guide +displayTitle: Bagel Programming Guide +title: Bagel --- **Bagel will soon be superseded by [GraphX](graphx-programming-guide.html); we recommend that new users try GraphX instead.** diff --git a/docs/configuration.md b/docs/configuration.md index 8b1d7598c47e4..4c86cb7c16238 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark Configuration +displayTitle: Spark Configuration +title: Configuration --- * This will become a table of contents (this text will be scraped). {:toc} diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index e298c51f8a5b7..826f6d8f371c7 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -1,6 +1,8 @@ --- layout: global -title: GraphX Programming Guide +displayTitle: GraphX Programming Guide +title: GraphX +description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). diff --git a/docs/index.md b/docs/index.md index 171d6ddad62f3..e006be640e582 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,8 @@ --- layout: global -title: Spark Overview +displayTitle: Spark Overview +title: Overview +description: Apache Spark SPARK_VERSION_SHORT documentation homepage --- Apache Spark is a fast and general-purpose cluster computing system. 
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index fc8e732251a30..d1537def851e7 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -1,7 +1,7 @@ --- layout: global -title: Decision Tree - MLlib -displayTitle: MLlib - Decision Tree +title: Decision Trees - MLlib +displayTitle: MLlib - Decision Trees --- * Table of contents diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 39c64d06926bf..73728bb35eb96 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -1,6 +1,8 @@ --- layout: global -title: Machine Learning Library (MLlib) Programming Guide +title: MLlib +displayTitle: Machine Learning Library (MLlib) Guide +description: MLlib machine learning library overview for Spark SPARK_VERSION_SHORT --- MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, diff --git a/docs/monitoring.md b/docs/monitoring.md index f32cdef240d31..7a5cadc171d6d 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -1,6 +1,7 @@ --- layout: global title: Monitoring and Instrumentation +description: Monitoring, metrics, and instrumentation guide for Spark SPARK_VERSION_SHORT --- There are several ways to monitor Spark applications: web UIs, metrics, and external instrumentation. diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 6486614e71354..6b365e83fb56d 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1,6 +1,7 @@ --- layout: global title: Spark Programming Guide +description: Spark SPARK_VERSION_SHORT programming guide in Java, Scala and Python --- * This will become a table of contents (this text will be scraped). diff --git a/docs/quick-start.md b/docs/quick-start.md index bf643bb70e153..81143da865cf0 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -1,6 +1,7 @@ --- layout: global title: Quick Start +description: Quick start tutorial for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). diff --git a/docs/security.md b/docs/security.md index 6e0a54fbc4ad7..c034ba12ff1fc 100644 --- a/docs/security.md +++ b/docs/security.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark Security +displayTitle: Spark Security +title: Security --- Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 22664b419f5cb..38f617d0c836c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark SQL Programming Guide +displayTitle: Spark SQL Programming Guide +title: Spark SQL --- * This will become a table of contents (this text will be scraped). @@ -1107,7 +1108,7 @@ in Hive deployments. have the same input format. * Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions (e.g. condition "`key < 10`"), Spark SQL will output wrong result for the `NULL` tuple. 
-* `UNION` type and `DATE` type +* `UNION` type * Unique join * Single query multi insert * Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index e37a2bb37b9a4..96fb12ce5e0b9 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1,6 +1,8 @@ --- layout: global -title: Spark Streaming Programming Guide +displayTitle: Spark Streaming Programming Guide +title: Spark Streaming +description: Spark Streaming programming guide and tutorial for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). diff --git a/docs/tuning.md b/docs/tuning.md index efaac9d3d405f..cbd227868b248 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -1,6 +1,8 @@ --- layout: global -title: Tuning Spark +displayTitle: Tuning Spark +title: Tuning +description: Tuning and performance optimization guide for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 87190e9d002b6..a1f4c1c4a7dab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -134,7 +134,7 @@ object LDAExample { .setTopicConcentration(params.topicConcentration) .setCheckpointInterval(params.checkpointInterval) if (params.checkpointDir.nonEmpty) { - lda.setCheckpointDir(params.checkpointDir.get) + sc.setCheckpointDir(params.checkpointDir.get) } val startTime = System.nanoTime() val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 74183864de6ff..3c7ba2a674cae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -54,6 +54,9 @@ import org.apache.spark.util.Utils * - Paper which clearly explains several algorithms, including EM: * Asuncion, Welling, Smyth, and Teh. * "On Smoothing and Inference for Topic Models." UAI, 2009. + * + * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation + * (Wikipedia)]] */ @Experimental class LDA private ( @@ -62,11 +65,10 @@ class LDA private ( private var docConcentration: Double, private var topicConcentration: Double, private var seed: Long, - private var checkpointDir: Option[String], private var checkpointInterval: Int) extends Logging { def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, - seed = Utils.random.nextLong(), checkpointDir = None, checkpointInterval = 10) + seed = Utils.random.nextLong(), checkpointInterval = 10) /** * Number of topics to infer. I.e., the number of soft cluster centers. @@ -202,50 +204,18 @@ class LDA private ( this } - /** - * Directory for storing checkpoint files during learning. - * This is not necessary, but checkpointing helps with recovery (when nodes fail). - * It also helps with eliminating temporary shuffle files on disk, which can be important when - * LDA is run for many iterations. - */ - def getCheckpointDir: Option[String] = checkpointDir - - /** - * Directory for storing checkpoint files during learning. 
- * This is not necessary, but checkpointing helps with recovery (when nodes fail). - * It also helps with eliminating temporary shuffle files on disk, which can be important when - * LDA is run for many iterations. - * - * NOTE: If the [[org.apache.spark.SparkContext.checkpointDir]] is already set, then the value - * given to LDA is ignored, and the existing directory is kept. - * - * (default = None) - */ - def setCheckpointDir(checkpointDir: String): this.type = { - this.checkpointDir = Some(checkpointDir) - this - } - - /** - * Clear the directory for storing checkpoint files during learning. - * If one is already set in the [[org.apache.spark.SparkContext]], then checkpointing will still - * occur; otherwise, no checkpointing will be used. - */ - def clearCheckpointDir(): this.type = { - this.checkpointDir = None - this - } - /** * Period (in iterations) between checkpoints. - * @see [[getCheckpointDir]] */ def getCheckpointInterval: Int = checkpointInterval /** - * Period (in iterations) between checkpoints. - * (default = 10) - * @see [[getCheckpointDir]] + * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery + * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be + * important when LDA is run for many iterations. If the checkpoint directory is not set in + * [[org.apache.spark.SparkContext]], this setting is ignored. + * + * @see [[org.apache.spark.SparkContext#setCheckpointDir]] */ def setCheckpointInterval(checkpointInterval: Int): this.type = { this.checkpointInterval = checkpointInterval @@ -269,7 +239,7 @@ class LDA private ( mode match { case LDAMode.EM => val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed, - checkpointDir, checkpointInterval) + checkpointInterval) var iter = 0 val iterationTimes = Array.fill[Double](maxIterations)(0) while (iter < maxIterations) { @@ -358,18 +328,18 @@ private[clustering] object LDA { * Vector over topics (length k) of token counts. * The meaning of these counts can vary, and it may or may not be normalized to be a distribution. */ - type TopicCounts = BDV[Double] + private[clustering] type TopicCounts = BDV[Double] - type TokenCount = Double + private[clustering] type TokenCount = Double /** Term vertex IDs are {-1, -2, ..., -vocabSize} */ - def term2index(term: Int): Long = -(1 + term.toLong) + private[clustering] def term2index(term: Int): Long = -(1 + term.toLong) - def index2term(termIndex: Long): Int = -(1 + termIndex).toInt + private[clustering] def index2term(termIndex: Long): Int = -(1 + termIndex).toInt - def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 + private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 - def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 + private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 /** * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. 
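Reviewer note, not part of the patch: with `checkpointDir` removed from LDA, callers now set the checkpoint location on the SparkContext. A minimal sketch of the resulting usage, assuming an existing `sc` and a prepared `corpus: RDD[(Long, Vector)]`; the directory path is hypothetical.

```scala
// Checkpointing for LDA now uses the directory configured on the SparkContext.
sc.setCheckpointDir("/tmp/lda-checkpoints")  // hypothetical path

val ldaModel = new LDA()
  .setK(10)
  .setCheckpointInterval(10)  // checkpoint the internal graph every 10 iterations
  .run(corpus)                // corpus: RDD[(Long, Vector)] of (document id, term counts)
```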
@@ -381,17 +351,16 @@ private[clustering] object LDA { * @param docConcentration "alpha" * @param topicConcentration "beta" or "eta" */ - class EMOptimizer( + private[clustering] class EMOptimizer( var graph: Graph[TopicCounts, TokenCount], val k: Int, val vocabSize: Int, val docConcentration: Double, val topicConcentration: Double, - checkpointDir: Option[String], checkpointInterval: Int) { private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( - graph, checkpointDir, checkpointInterval) + graph, checkpointInterval) def next(): EMOptimizer = { val eta = topicConcentration @@ -580,7 +549,6 @@ private[clustering] object LDA { docConcentration: Double, topicConcentration: Double, randomSeed: Long, - checkpointDir: Option[String], checkpointInterval: Int): EMOptimizer = { // For each document, create an edge (Document -> Term) for each unique term in the document. val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => @@ -624,8 +592,7 @@ private[clustering] object LDA { val graph = Graph(docVertices ++ termVertices, edges) .partitionBy(PartitionStrategy.EdgePartition1D) - new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointDir, - checkpointInterval) + new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index 76672fe51e834..6e5dd119dd653 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -74,7 +74,6 @@ import org.apache.spark.storage.StorageLevel * }}} * * @param currentGraph Initial graph - * @param checkpointDir The directory for storing checkpoint files * @param checkpointInterval Graphs will be checkpointed at this interval * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type @@ -83,7 +82,6 @@ import org.apache.spark.storage.StorageLevel */ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( var currentGraph: Graph[VD, ED], - val checkpointDir: Option[String], val checkpointInterval: Int) extends Logging { /** FIFO queue of past checkpointed RDDs */ @@ -101,12 +99,6 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( */ private val sc = currentGraph.vertices.sparkContext - // If a checkpoint directory is given, and there's no prior checkpoint directory, - // then set the checkpoint directory with the given one. - if (checkpointDir.nonEmpty && sc.getCheckpointDir.isEmpty) { - sc.setCheckpointDir(checkpointDir.get) - } - updateGraph(currentGraph) /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 482dd4b272d1d..45b0154c5e4cb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree +import java.io.IOException + import scala.collection.mutable import scala.collection.JavaConverters._ @@ -244,7 +246,12 @@ private class RandomForest ( // Delete any remaining checkpoints used for node Id cache. 
if (nodeIdCache.nonEmpty) { - nodeIdCache.get.deleteAllCheckpoints() + try { + nodeIdCache.get.deleteAllCheckpoints() + } catch { + case e: IOException => + logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") + } } val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index dac28a369b5b2..699f009f0f2ec 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -38,7 +38,7 @@ class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext var graphsToCheck = Seq.empty[GraphToCheck] val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, None, 10) + val checkpointer = new PeriodicGraphCheckpointer(graph1, 10) graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) checkPersistence(graphsToCheck, 1) @@ -57,9 +57,9 @@ class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext val path = tempDir.toURI.toString val checkpointInterval = 2 var graphsToCheck = Seq.empty[GraphToCheck] - + sc.setCheckpointDir(path) val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, Some(path), checkpointInterval) + val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval) graph1.edges.count() graph1.vertices.count() graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) diff --git a/pom.xml b/pom.xml index d324b5f0ec93a..aef450ae63121 100644 --- a/pom.xml +++ b/pom.xml @@ -135,8 +135,8 @@ 1.6.0rc3 1.2.3 8.1.14.v20131031 + 3.0.0.v201112011016 0.5.0 - 2.24.0 2.4.0 2.0.8 3.1.0 @@ -341,13 +341,7 @@ - - com.esotericsoftware.kryo - kryo - ${kryo.version} - - + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 93698efe84252..f63f9c1982bb5 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -374,7 +374,10 @@ object Unidoc { ), "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" - ) + ), + + // Group similar methods together based on the @group annotation. + scalacOptions in (ScalaUnidoc, unidoc) ++= Seq("-groups") ) } diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 3ac8ea597e142..e55f285a778c4 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1471,7 +1471,7 @@ def registerRDDAsTable(self, rdd, tableName): else: raise ValueError("Can only register DataFrame as table") - def parquetFile(self, path): + def parquetFile(self, *paths): """Loads a Parquet file, returning the result as a L{DataFrame}.
>>> import tempfile, shutil @@ -1483,7 +1483,12 @@ def parquetFile(self, path): >>> sorted(df.collect()) == sorted(df2.collect()) True """ - jdf = self._ssql_ctx.parquetFile(path) + gateway = self._sc._gateway + jpath = paths[0] + jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths) - 1) + for i in range(1, len(paths)): + jpaths[i - 1] = paths[i] + jdf = self._ssql_ctx.parquetFile(jpath, jpaths) return DataFrame(jdf, self) def jsonFile(self, path, schema=None, samplingRatio=1.0): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c84cc95520a19..365b1685a8e71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.{BinaryType, BooleanType} object InterpretedPredicate { def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -175,7 +175,10 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison null } else { val r = right.eval(input) - if (r == null) null else l == r + if (r == null) null + else if (left.dataType != BinaryType) l == r + else BinaryType.ordering.compare( + l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) == 0 } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index a6d6ddd905393..91efe320546a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.types import java.sql.Timestamp +import scala.collection.mutable.ArrayBuffer import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} @@ -29,6 +30,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} @@ -159,7 +161,6 @@ object DataType { case failure: NoSuccess => throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") } - } protected[types] def buildFormattedString( @@ -227,8 +228,9 @@ abstract class DataType { def json: String = compact(render(jsonValue)) def prettyJson: String = pretty(render(jsonValue)) -} + def simpleString: String = typeName +} /** * :: DeveloperApi :: @@ -242,7 +244,6 @@ case object NullType extends DataType { override def defaultSize: Int = 1 } - protected[sql] object NativeType { val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) @@ -448,6 +449,8 @@ case object LongType extends IntegralType { * The default size of a value of the LongType is 8 bytes.
*/ override def defaultSize: Int = 8 + + override def simpleString = "bigint" } @@ -470,6 +473,8 @@ case object IntegerType extends IntegralType { * The default size of a value of the IntegerType is 4 bytes. */ override def defaultSize: Int = 4 + + override def simpleString = "int" } @@ -492,6 +497,8 @@ case object ShortType extends IntegralType { * The default size of a value of the ShortType is 2 bytes. */ override def defaultSize: Int = 2 + + override def simpleString = "smallint" } @@ -514,6 +521,8 @@ case object ByteType extends IntegralType { * The default size of a value of the ByteType is 1 byte. */ override def defaultSize: Int = 1 + + override def simpleString = "tinyint" } @@ -573,6 +582,11 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT * The default size of a value of the DecimalType is 4096 bytes. */ override def defaultSize: Int = 4096 + + override def simpleString = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" + case None => "decimal(10,0)" + } } @@ -695,6 +709,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT * (We assume that there are 100 elements). */ override def defaultSize: Int = 100 * elementType.defaultSize + + override def simpleString = s"array<${elementType.simpleString}>" } @@ -739,6 +755,57 @@ object StructType { def apply(fields: java.util.List[StructField]): StructType = { StructType(fields.toArray.asInstanceOf[Array[StructField]]) } + + private[sql] def merge(left: DataType, right: DataType): DataType = + (left, right) match { + case (ArrayType(leftElementType, leftContainsNull), + ArrayType(rightElementType, rightContainsNull)) => + ArrayType( + merge(leftElementType, rightElementType), + leftContainsNull || rightContainsNull) + + case (MapType(leftKeyType, leftValueType, leftContainsNull), + MapType(rightKeyType, rightValueType, rightContainsNull)) => + MapType( + merge(leftKeyType, rightKeyType), + merge(leftValueType, rightValueType), + leftContainsNull || rightContainsNull) + + case (StructType(leftFields), StructType(rightFields)) => + val newFields = ArrayBuffer.empty[StructField] + + leftFields.foreach { + case leftField @ StructField(leftName, leftType, leftNullable, _) => + rightFields + .find(_.name == leftName) + .map { case rightField @ StructField(_, rightType, rightNullable, _) => + leftField.copy( + dataType = merge(leftType, rightType), + nullable = leftNullable || rightNullable) + } + .orElse(Some(leftField)) + .foreach(newFields += _) + } + + rightFields + .filterNot(f => leftFields.map(_.name).contains(f.name)) + .foreach(newFields += _) + + StructType(newFields) + + case (DecimalType.Fixed(leftPrecision, leftScale), + DecimalType.Fixed(rightPrecision, rightScale)) => + DecimalType(leftPrecision.max(rightPrecision), leftScale.max(rightScale)) + + case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) + if leftUdt.userClass == rightUdt.userClass => leftUdt + + case (leftType, rightType) if leftType == rightType => + leftType + + case _ => + throw new SparkException(s"Failed to merge incompatible data types $left and $right") + } } @@ -870,6 +937,25 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * The default size of a value of the StructType is the total default sizes of all field types. 
*/ override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum + + override def simpleString = { + val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}") + s"struct<${fieldTypes.mkString(",")}>" + } + + /** + * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field + * B from `that`, + * + * 1. If A and B have the same name and data type, they are merged to a field C with the same name + * and data type. C is nullable if and only if either A or B is nullable. + * 2. If A doesn't exist in `that`, it's included in the result schema. + * 3. If B doesn't exist in `this`, it's also included in the result schema. + * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be + * thrown. + */ + private[sql] def merge(that: StructType): StructType = + StructType.merge(this, that).asInstanceOf[StructType] } @@ -920,6 +1006,8 @@ case class MapType( * (We assume that there are 100 elements). */ override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize) + + override def simpleString = s"map<${keyType.simpleString},${valueType.simpleString}>" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index d6df927f9d42c..4911443dd6dde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -53,7 +53,9 @@ private[sql] class DataFrameImpl protected[sql]( def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = { this(sqlContext, { val qe = sqlContext.executePlan(logicalPlan) - qe.analyzed // This should force analysis and throw errors if there are any + if (sqlContext.conf.dataFrameEagerAnalysis) { + qe.analyzed // This should force analysis and throw errors if there are any + } qe }) } @@ -295,7 +297,11 @@ private[sql] class DataFrameImpl protected[sql]( } override def saveAsParquetFile(path: String): Unit = { - sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd + if (sqlContext.conf.parquetUseDataSourceApi) { + save("org.apache.spark.sql.parquet", "path" -> path) + } else { + sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd + } } override def saveAsTable(tableName: String): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 7fe17944a734e..180f5e765fb91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -37,6 +37,7 @@ private[spark] object SQLConf { val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" + val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi" val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" val BROADCAST_TIMEOUT = "spark.sql.broadcastTimeout" @@ -51,6 +52,9 @@ private[spark] object SQLConf { // This is used to set the default data source val DEFAULT_DATA_SOURCE_NAME = "spark.sql.default.datasource" + // Whether to perform eager analysis on a DataFrame. 
+ val DATAFRAME_EAGER_ANALYSIS = "spark.sql.dataframe.eagerAnalysis" + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -105,6 +109,10 @@ private[sql] class SQLConf extends Serializable { private[spark] def parquetFilterPushDown = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED, "false").toBoolean + /** When true uses Parquet implementation based on data source API */ + private[spark] def parquetUseDataSourceApi = + getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean + /** When true the planner will use the external sort, which may spill to disk. */ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean @@ -168,6 +176,9 @@ private[sql] class SQLConf extends Serializable { private[spark] def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.parquet") + private[spark] def dataFrameEagerAnalysis: Boolean = + getConf(DATAFRAME_EAGER_ANALYSIS, "true").toBoolean + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 01620aa0acd49..706ef6ad4f174 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -20,14 +20,13 @@ package org.apache.spark.sql import java.beans.Introspector import java.util.Properties -import scala.collection.immutable import scala.collection.JavaConversions._ +import scala.collection.immutable import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.{SparkContext, Partition} import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ @@ -36,11 +35,12 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ -import org.apache.spark.sql.json._ import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.sources._ +import org.apache.spark.sql.json._ +import org.apache.spark.sql.sources.{BaseRelation, DDLParser, DataSourceStrategy, LogicalRelation, _} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +import org.apache.spark.{Partition, SparkContext} /** * :: AlphaComponent :: @@ -303,8 +303,14 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def parquetFile(path: String): DataFrame = - DataFrame(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) + @scala.annotation.varargs + def parquetFile(path: String, paths: String*): DataFrame = + if (conf.parquetUseDataSourceApi) { + baseRelationToDataFrame(parquet.ParquetRelation2(path +: paths, Map.empty)(this)) + } else { + DataFrame(this, parquet.ParquetRelation( + (path +: paths).mkString(","), Some(sparkContext.hadoopConfiguration), this)) + } /** * Loads a JSON file (one object per line), returning the result as a [[DataFrame]].
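Reviewer note, not part of the patch: a minimal sketch of how the new SQLConf flags and the varargs `parquetFile` are exercised, assuming an existing `sqlContext`; the paths are hypothetical.

```scala
// Both flags default to "true"; set explicitly here only to make the new knobs visible.
sqlContext.setConf("spark.sql.parquet.useDataSourceApi", "true")
sqlContext.setConf("spark.sql.dataframe.eagerAnalysis", "true")

// Multiple Parquet paths can now be read into a single DataFrame.
val df = sqlContext.parquetFile("/data/events/day1", "/data/events/day2")
df.schema.fields.foreach(f => println(s"${f.name}: ${f.dataType.simpleString}"))
```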
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0c77d399b2eb8..81bcf5a6f32dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.parquet._ +import org.apache.spark.sql.sources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.types._ -import org.apache.spark.sql.sources._ - +import org.apache.spark.sql.{SQLContext, Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -337,6 +337,16 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case c: CreateTableUsingAsLogicalPlan if c.temporary && c.allowExisting => sys.error("allowExisting should be set to false when creating a temporary table.") + case LogicalDescribeCommand(table, isExtended) => + val resultPlan = self.sqlContext.executePlan(table).executedPlan + ExecutedCommand( + RunnableDescribeCommand(resultPlan, resultPlan.output, isExtended)) :: Nil + + case LogicalDescribeCommand(table, isExtended) => + val resultPlan = self.sqlContext.executePlan(table).executedPlan + ExecutedCommand( + RunnableDescribeCommand(resultPlan, resultPlan.output, isExtended)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 1bc53968c4ca3..335757087deef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import scala.collection.mutable.ArrayBuffer /** * A logical command that is executed for its side-effects. 
`RunnableCommand`s are @@ -176,9 +177,14 @@ case class UncacheTableCommand(tableName: String) extends RunnableCommand { @DeveloperApi case class DescribeCommand( child: SparkPlan, - override val output: Seq[Attribute]) extends RunnableCommand { + override val output: Seq[Attribute], + isExtended: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext) = { - child.output.map(field => Row(field.name, field.dataType.toString, null)) + child.schema.fields.map { field => + val cmtKey = "comment" + val comment = if (field.metadata.contains(cmtKey)) field.metadata.getString(cmtKey) else "" + Row(field.name, field.dataType.simpleString, comment) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 8372decbf8aa1..f27585d05a986 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.StructType private[sql] class DefaultSource - extends RelationProvider with SchemaRelationProvider with CreateableRelationProvider { + extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider { /** Returns a new base relation with the parameters. */ override def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 14c81ae4eba4e..19bfba34b8f4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -159,7 +159,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { val attributesSize = attributes.size if (attributesSize > record.size) { throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row (${attributesSize}>${record.size})") + s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") } var index = 0 @@ -325,7 +325,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { val attributesSize = attributes.size if (attributesSize > record.size) { throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row (${attributesSize}>${record.size})") + s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") } var index = 0 @@ -348,10 +348,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { index: Int): Unit = { ctype match { case StringType => writer.addBinary( - Binary.fromByteArray( - record(index).asInstanceOf[String].getBytes("utf-8") - ) - ) + Binary.fromByteArray(record(index).asInstanceOf[String].getBytes("utf-8"))) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(record.getInt(index)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index b646109b7c553..5209581fa8357 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -19,24 +19,23 @@ package org.apache.spark.sql.parquet import java.io.IOException +import scala.collection.mutable.ArrayBuffer import 
scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job - import parquet.format.converter.ParquetMetadataConverter -import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} -import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} +import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} import parquet.hadoop.util.ContextUtil -import parquet.schema.{Type => ParquetType, Types => ParquetTypes, PrimitiveType => ParquetPrimitiveType, MessageType} -import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns, DecimalMetadata} +import parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} import parquet.schema.Type.Repetition +import parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ +import org.apache.spark.{Logging, SparkException} // Implicits import scala.collection.JavaConversions._ @@ -285,7 +284,7 @@ private[parquet] object ParquetTypesConverter extends Logging { ctype: DataType, name: String, nullable: Boolean = true, - inArray: Boolean = false, + inArray: Boolean = false, toThriftSchemaNames: Boolean = false): ParquetType = { val repetition = if (inArray) { @@ -340,7 +339,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } case StructType(structFields) => { val fields = structFields.map { - field => fromDataType(field.dataType, field.name, field.nullable, + field => fromDataType(field.dataType, field.name, field.nullable, inArray = false, toThriftSchemaNames) } new ParquetGroupType(repetition, name, fields.toSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 179c0d6b22239..49d46334b6525 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -16,25 +16,38 @@ */ package org.apache.spark.sql.parquet -import java.util.{List => JList} +import java.io.IOException +import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} +import java.math.{BigDecimal => JBigDecimal} +import java.text.SimpleDateFormat +import java.util.{List => JList, Date} import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.util.Try -import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} import parquet.filter2.predicate.FilterApi -import parquet.hadoop.ParquetInputFormat +import parquet.format.converter.ParquetMetadataConverter +import parquet.hadoop.{ParquetInputFormat, _} import parquet.hadoop.util.ContextUtil import org.apache.spark.annotation.DeveloperApi 
-import org.apache.spark.rdd.{NewHadoopPartition, RDD} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.parquet.ParquetTypesConverter._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} -import org.apache.spark.sql.{Row, SQLConf, SQLContext} -import org.apache.spark.{Logging, Partition => SparkPartition} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _} +import org.apache.spark.sql.types.StructType._ +import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} +import org.apache.spark.{Partition => SparkPartition, TaskContext, SerializableWritable, Logging, SparkException} /** @@ -43,19 +56,49 @@ import org.apache.spark.{Logging, Partition => SparkPartition} * required is `path`, which should be the location of a collection of, optionally partitioned, * parquet files. */ -class DefaultSource extends RelationProvider { +class DefaultSource + extends RelationProvider + with SchemaRelationProvider + with CreatableRelationProvider { + private def checkPath(parameters: Map[String, String]): String = { + parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) + } + /** Returns a new base relation with the given parameters. */ override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - val path = - parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) + ParquetRelation2(Seq(checkPath(parameters)), parameters, None)(sqlContext) + } - ParquetRelation2(path)(sqlContext) + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + ParquetRelation2(Seq(checkPath(parameters)), parameters, Some(schema))(sqlContext) + } + + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + val path = checkPath(parameters) + ParquetRelation.createEmpty( + path, + data.schema.toAttributes, + false, + sqlContext.sparkContext.hadoopConfiguration, + sqlContext) + + val relation = createRelation(sqlContext, parameters, data.schema) + relation.asInstanceOf[ParquetRelation2].insert(data, true) + relation } } -private[parquet] case class Partition(partitionValues: Map[String, Any], files: Seq[FileStatus]) +private[parquet] case class Partition(values: Row, path: String) + +private[parquet] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) /** * An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is @@ -81,117 +124,196 @@ private[parquet] case class Partition(partitionValues: Map[String, Any], files: * discovery. */ @DeveloperApi -case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) - extends CatalystScan with Logging { +case class ParquetRelation2 + (paths: Seq[String], parameters: Map[String, String], maybeSchema: Option[StructType] = None) + (@transient val sqlContext: SQLContext) + extends CatalystScan + with InsertableRelation + with SparkHadoopMapReduceUtil + with Logging { + + // Should we merge schemas from all Parquet part-files? 
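+  // Controlled by the `mergeSchema` option (defaults to true). When enabled, the footers of
+  // all part-files are read and their Spark SQL schemas merged; when disabled, only a summary
+  // file (or, failing that, a single part-file) is consulted. See `readSchema()` below.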
+ private val shouldMergeSchemas = + parameters.getOrElse(ParquetRelation2.MERGE_SCHEMA, "true").toBoolean + + // Optional Metastore schema, used when converting Hive Metastore Parquet table + private val maybeMetastoreSchema = + parameters + .get(ParquetRelation2.METASTORE_SCHEMA) + .map(s => DataType.fromJson(s).asInstanceOf[StructType]) + + // Hive uses this as part of the default partition name when the partition column value is null + // or empty string + private val defaultPartitionName = parameters.getOrElse( + ParquetRelation2.DEFAULT_PARTITION_NAME, "__HIVE_DEFAULT_PARTITION__") + + override def equals(other: Any) = other match { + case relation: ParquetRelation2 => + paths.toSet == relation.paths.toSet && + maybeMetastoreSchema == relation.maybeMetastoreSchema && + (shouldMergeSchemas == relation.shouldMergeSchemas || schema == relation.schema) + } - def sparkContext = sqlContext.sparkContext + private[sql] def sparkContext = sqlContext.sparkContext - // Minor Hack: scala doesnt seem to respect @transient for vals declared via extraction - @transient - private var partitionKeys: Seq[String] = _ - @transient - private var partitions: Seq[Partition] = _ - discoverPartitions() + @transient private val fs = FileSystem.get(sparkContext.hadoopConfiguration) - // TODO: Only finds the first partition, assumes the key is of type Integer... - private def discoverPartitions() = { - val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration) - val partValue = "([^=]+)=([^=]+)".r + private class MetadataCache { + private var metadataStatuses: Array[FileStatus] = _ + private var commonMetadataStatuses: Array[FileStatus] = _ + private var footers: Map[FileStatus, Footer] = _ + private var parquetSchema: StructType = _ - val childrenOfPath = fs.listStatus(new Path(path)).filterNot(_.getPath.getName.startsWith("_")) - val childDirs = childrenOfPath.filter(s => s.isDir) + var dataStatuses: Array[FileStatus] = _ + var partitionSpec: PartitionSpec = _ + var schema: StructType = _ + var dataSchemaIncludesPartitionKeys: Boolean = _ - if (childDirs.size > 0) { - val partitionPairs = childDirs.map(_.getPath.getName).map { - case partValue(key, value) => (key, value) + def refresh(): Unit = { + val baseStatuses = { + val statuses = paths.distinct.map(p => fs.getFileStatus(fs.makeQualified(new Path(p)))) + // Support either reading a collection of raw Parquet part-files, or a collection of folders + // containing Parquet files (e.g. partitioned Parquet table). + assert(statuses.forall(!_.isDir) || statuses.forall(_.isDir)) + statuses.toArray } - val foundKeys = partitionPairs.map(_._1).distinct - if (foundKeys.size > 1) { - sys.error(s"Too many distinct partition keys: $foundKeys") + val leaves = baseStatuses.flatMap { f => + val statuses = SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f => + isSummaryFile(f.getPath) || + !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) + } + assert(statuses.nonEmpty, s"${f.getPath} is an empty folder.") + statuses } - // Do a parallel lookup of partition metadata. - val partitionFiles = - childDirs.par.map { d => - fs.listStatus(d.getPath) - // TODO: Is there a standard hadoop function for this? 
- .filterNot(_.getPath.getName.startsWith("_")) - .filterNot(_.getPath.getName.startsWith(".")) - }.seq - - partitionKeys = foundKeys.toSeq - partitions = partitionFiles.zip(partitionPairs).map { case (files, (key, value)) => - Partition(Map(key -> value.toInt), files) - }.toSeq - } else { - partitionKeys = Nil - partitions = Partition(Map.empty, childrenOfPath) :: Nil - } - } + dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) + metadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) + commonMetadataStatuses = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) + + footers = (dataStatuses ++ metadataStatuses ++ commonMetadataStatuses).par.map { f => + val parquetMetadata = ParquetFileReader.readFooter( + sparkContext.hadoopConfiguration, f, ParquetMetadataConverter.NO_FILTER) + f -> new Footer(f.getPath, parquetMetadata) + }.seq.toMap + + partitionSpec = { + val partitionDirs = dataStatuses + .filterNot(baseStatuses.contains) + .map(_.getPath.getParent) + .distinct + + if (partitionDirs.nonEmpty) { + ParquetRelation2.parsePartitions(partitionDirs, defaultPartitionName) + } else { + // No partition directories found, makes an empty specification + PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition]) + } + } - override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum + parquetSchema = maybeSchema.getOrElse(readSchema()) - val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes. - ParquetTypesConverter.readSchemaFromFile( - partitions.head.files.head.getPath, - Some(sparkContext.hadoopConfiguration), - sqlContext.conf.isParquetBinaryAsString, - sqlContext.conf.isParquetINT96AsTimestamp)) + dataSchemaIncludesPartitionKeys = + isPartitioned && + partitionColumns.forall(f => metadataCache.parquetSchema.fieldNames.contains(f.name)) - val dataIncludesKey = - partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true) + schema = { + val fullParquetSchema = if (dataSchemaIncludesPartitionKeys) { + metadataCache.parquetSchema + } else { + StructType(metadataCache.parquetSchema.fields ++ partitionColumns.fields) + } - override val schema = - if (dataIncludesKey) { - dataSchema - } else { - StructType(dataSchema.fields :+ StructField(partitionKeys.head, IntegerType)) + maybeMetastoreSchema + .map(ParquetRelation2.mergeMetastoreParquetSchema(_, fullParquetSchema)) + .getOrElse(fullParquetSchema) + } } - override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = { - // This is mostly a hack so that we can use the existing parquet filter code. - val requiredColumns = output.map(_.name) + private def readSchema(): StructType = { + // Sees which file(s) we need to touch in order to figure out the schema. + val filesToTouch = + // Always tries the summary files first if users don't require a merged schema. In this case, + // "_common_metadata" is more preferable than "_metadata" because it doesn't contain row + // groups information, and could be much smaller for large Parquet files with lots of row + // groups. + // + // NOTE: Metadata stored in the summary files are merged from all part-files. However, for + // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know + // how to merge them correctly if some key is associated with different values in different + // part-files. When this happens, Parquet simply gives up generating the summary file. 
This + // implies that if a summary file presents, then: + // + // 1. Either all part-files have exactly the same Spark SQL schema, or + // 2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus + // their schemas may differ from each other). + // + // Here we tend to be pessimistic and take the second case into account. Basically this means + // we can't trust the summary files if users require a merged schema, and must touch all part- + // files to do the merge. + if (shouldMergeSchemas) { + // Also includes summary files, 'cause there might be empty partition directories. + (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq + } else { + // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet + // don't have this. + commonMetadataStatuses.headOption + // Falls back to "_metadata" + .orElse(metadataStatuses.headOption) + // Summary file(s) not found, the Parquet file is either corrupted, or different part- + // files contain conflicting user defined metadata (two or more values are associated + // with a same key in different files). In either case, we fall back to any of the + // first part-file, and just assume all schemas are consistent. + .orElse(dataStatuses.headOption) + .toSeq + } - val job = new Job(sparkContext.hadoopConfiguration) - ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) - val jobConf: Configuration = ContextUtil.getConfiguration(job) + ParquetRelation2.readSchema(filesToTouch.map(footers.apply), sqlContext) + } + } - val requestedSchema = StructType(requiredColumns.map(schema(_))) + @transient private val metadataCache = new MetadataCache + metadataCache.refresh() - val partitionKeySet = partitionKeys.toSet - val rawPredicate = - predicates - .filter(_.references.map(_.name).toSet.subsetOf(partitionKeySet)) - .reduceOption(And) - .getOrElse(Literal(true)) + private def partitionColumns = metadataCache.partitionSpec.partitionColumns - // Translate the predicate so that it reads from the information derived from the - // folder structure - val castedPredicate = rawPredicate transform { - case a: AttributeReference => - val idx = partitionKeys.indexWhere(a.name == _) - BoundReference(idx, IntegerType, nullable = true) - } + private def partitions = metadataCache.partitionSpec.partitions - val inputData = new GenericMutableRow(partitionKeys.size) - val pruningCondition = InterpretedPredicate(castedPredicate) + private def isPartitioned = partitionColumns.nonEmpty - val selectedPartitions = - if (partitionKeys.nonEmpty && predicates.nonEmpty) { - partitions.filter { part => - inputData(0) = part.partitionValues.values.head - pruningCondition(inputData) - } - } else { - partitions + private def dataSchemaIncludesPartitionKeys = metadataCache.dataSchemaIncludesPartitionKeys + + override def schema = metadataCache.schema + + private def isSummaryFile(file: Path): Boolean = { + file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || + file.getName == ParquetFileWriter.PARQUET_METADATA_FILE + } + + // TODO Should calculate per scan size + // It's common that a query only scans a fraction of a large Parquet file. Returning size of the + // whole Parquet file disables some optimizations in this case (e.g. broadcast join). + override val sizeInBytes = metadataCache.dataStatuses.map(_.getLen).sum + + // This is mostly a hack so that we can use the existing parquet filter code. 
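+  //
+  // Scan flow: (1) prune partition directories using predicates that reference only partition
+  // columns, (2) hand the surviving part-files to FilteringParquetRowInputFormat, pushing row
+  // group filters down when possible, and (3) if the partition columns are not stored in the
+  // data files themselves, patch their values back into each row from the directory path.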
+ override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = { + val job = new Job(sparkContext.hadoopConfiguration) + ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) + val jobConf: Configuration = ContextUtil.getConfiguration(job) + + val selectedPartitions = prunePartitions(predicates, partitions) + val selectedFiles = if (isPartitioned) { + selectedPartitions.flatMap { p => + metadataCache.dataStatuses.filter(_.getPath.getParent.toString == p.path) } + } else { + metadataCache.dataStatuses.toSeq + } - val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration) - val selectedFiles = selectedPartitions.flatMap(_.files).map(f => fs.makeQualified(f.getPath)) // FileInputFormat cannot handle empty lists. if (selectedFiles.nonEmpty) { - org.apache.hadoop.mapreduce.lib.input.FileInputFormat.setInputPaths(job, selectedFiles: _*) + FileInputFormat.setInputPaths(job, selectedFiles.map(_.getPath): _*) } // Push down filters when possible. Notice that not all filters can be converted to Parquet @@ -203,23 +325,28 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) .filter(_ => sqlContext.conf.parquetFilterPushDown) .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) - def percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 - logInfo(s"Reading $percentRead% of $path partitions") + if (isPartitioned) { + def percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 + logInfo(s"Reading $percentRead% of partitions") + } + + val requiredColumns = output.map(_.name) + val requestedSchema = StructType(requiredColumns.map(schema(_))) // Store both requested and original schema in `Configuration` jobConf.set( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(requestedSchema.toAttributes)) + convertToString(requestedSchema.toAttributes)) jobConf.set( RowWriteSupport.SPARK_ROW_SCHEMA, - ParquetTypesConverter.convertToString(schema.toAttributes)) + convertToString(schema.toAttributes)) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata val useCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean jobConf.set(SQLConf.PARQUET_CACHE_METADATA, useCache.toString) val baseRDD = - new org.apache.spark.rdd.NewHadoopRDD( + new NewHadoopRDD( sparkContext, classOf[FilteringParquetRowInputFormat], classOf[Void], @@ -228,66 +355,400 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) val cacheMetadata = useCache @transient - val cachedStatus = selectedPartitions.flatMap(_.files) + val cachedStatus = selectedFiles // Overridden so we can inject our own cached files statuses. 
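+      // Reusing the statuses already listed on the driver avoids a second round of file
+      // listing (and the associated namenode RPCs) when the input format computes its splits.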
override def getPartitions: Array[SparkPartition] = { - val inputFormat = - if (cacheMetadata) { - new FilteringParquetRowInputFormat { - override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus - } - } else { - new FilteringParquetRowInputFormat + val inputFormat = if (cacheMetadata) { + new FilteringParquetRowInputFormat { + override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus } - - inputFormat match { - case configurable: Configurable => - configurable.setConf(getConf) - case _ => + } else { + new FilteringParquetRowInputFormat } + val jobContext = newJobContext(getConf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[SparkPartition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = - new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + val rawSplits = inputFormat.getSplits(jobContext) + + Array.tabulate[SparkPartition](rawSplits.size) { i => + new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) } - result } } - // The ordinal for the partition key in the result row, if requested. - val partitionKeyLocation = - partitionKeys - .headOption - .map(requiredColumns.indexOf(_)) - .getOrElse(-1) + // The ordinals for partition keys in the result row, if requested. + val partitionKeyLocations = partitionColumns.fieldNames.zipWithIndex.map { + case (name, index) => index -> requiredColumns.indexOf(name) + }.toMap.filter { + case (_, index) => index >= 0 + } // When the data does not include the key and the key is requested then we must fill it in // based on information from the input split. - if (!dataIncludesKey && partitionKeyLocation != -1) { - baseRDD.mapPartitionsWithInputSplit { case (split, iter) => - val partValue = "([^=]+)=([^=]+)".r - val partValues = - split.asInstanceOf[parquet.hadoop.ParquetInputSplit] - .getPath - .toString - .split("/") - .flatMap { - case partValue(key, value) => Some(key -> value) - case _ => None - }.toMap - - val currentValue = partValues.values.head.toInt - iter.map { pair => - val res = pair._2.asInstanceOf[SpecificMutableRow] - res.setInt(partitionKeyLocation, currentValue) - res + if (!dataSchemaIncludesPartitionKeys && partitionKeyLocations.nonEmpty) { + baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) => + val partValues = selectedPartitions.collectFirst { + case p if split.getPath.getParent.toString == p.path => p.values + }.get + + iterator.map { pair => + val row = pair._2.asInstanceOf[SpecificMutableRow] + var i = 0 + while (i < partValues.size) { + // TODO Avoids boxing cost here! 
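+            // Write the i-th partition value (parsed from this split's directory path) into
+            // the ordinal at which the query requested that partition column.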
+ row.update(partitionKeyLocations(i), partValues(i)) + i += 1 + } + row } } } else { baseRDD.map(_._2) } } + + private def prunePartitions( + predicates: Seq[Expression], + partitions: Seq[Partition]): Seq[Partition] = { + val partitionColumnNames = partitionColumns.map(_.name).toSet + val partitionPruningPredicates = predicates.filter { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + + val rawPredicate = partitionPruningPredicates.reduceOption(And).getOrElse(Literal(true)) + val boundPredicate = InterpretedPredicate(rawPredicate transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }) + + if (isPartitioned && partitionPruningPredicates.nonEmpty) { + partitions.filter(p => boundPredicate(p.values)) + } else { + partitions + } + } + + override def insert(data: DataFrame, overwrite: Boolean): Unit = { + // TODO: currently we do not check whether the "schema"s are compatible + // That means if one first creates a table and then INSERTs data with + // and incompatible schema the execution will fail. It would be nice + // to catch this early one, maybe having the planner validate the schema + // before calling execute(). + + val job = new Job(sqlContext.sparkContext.hadoopConfiguration) + val writeSupport = if (schema.map(_.dataType).forall(_.isPrimitive)) { + log.debug("Initializing MutableRowWriteSupport") + classOf[MutableRowWriteSupport] + } else { + classOf[RowWriteSupport] + } + + ParquetOutputFormat.setWriteSupportClass(job, writeSupport) + + val conf = ContextUtil.getConfiguration(job) + RowWriteSupport.setSchema(schema.toAttributes, conf) + + val destinationPath = new Path(paths.head) + + if (overwrite) { + try { + destinationPath.getFileSystem(conf).delete(destinationPath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${destinationPath.toString} prior" + + s" to writing to Parquet file:\n${e.toString}") + } + } + + job.setOutputKeyClass(classOf[Void]) + job.setOutputValueClass(classOf[Row]) + FileOutputFormat.setOutputPath(job, destinationPath) + + val wrappedConf = new SerializableWritable(job.getConfiguration) + val jobTrackerId = new SimpleDateFormat("yyyyMMddHHmm").format(new Date()) + val stageId = sqlContext.sparkContext.newRddId() + + val taskIdOffset = if (overwrite) { + 1 + } else { + FileSystemHelper.findMaxTaskId( + FileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1 + } + + def writeShard(context: TaskContext, iterator: Iterator[Row]): Unit = { + /* "reduce task" */ + val attemptId = newTaskAttemptID( + jobTrackerId, stageId, isMap = false, context.partitionId(), context.attemptNumber()) + val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) + val format = new AppendingParquetOutputFormat(taskIdOffset) + val committer = format.getOutputCommitter(hadoopContext) + committer.setupTask(hadoopContext) + val writer = format.getRecordWriter(hadoopContext) + try { + while (iterator.hasNext) { + val row = iterator.next() + writer.write(null, row) + } + } finally { + writer.close(hadoopContext) + } + committer.commitTask(hadoopContext) + } + val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) + /* apparently we need a TaskAttemptID to construct an OutputCommitter; + * however we're only going to use this local OutputCommitter for + * setupJob/commitJob, so we just use a dummy "map" task. 
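+     * Each real write task constructs its own committer inside `writeShard` and calls
+     * setupTask/commitTask there.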
+ */ + val jobAttemptId = newTaskAttemptID(jobTrackerId, stageId, isMap = true, 0, 0) + val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + + jobCommitter.setupJob(jobTaskContext) + sqlContext.sparkContext.runJob(data.queryExecution.executedPlan.execute(), writeShard _) + jobCommitter.commitJob(jobTaskContext) + + metadataCache.refresh() + } +} + +object ParquetRelation2 { + // Whether we should merge schemas collected from all Parquet part-files. + val MERGE_SCHEMA = "mergeSchema" + + // Hive Metastore schema, passed in when the Parquet relation is converted from Metastore + val METASTORE_SCHEMA = "metastoreSchema" + + // Default partition name to use when the partition column value is null or empty string + val DEFAULT_PARTITION_NAME = "partition.defaultName" + + // When true, the Parquet data source caches Parquet metadata for performance + val CACHE_METADATA = "cacheMetadata" + + private[parquet] def readSchema(footers: Seq[Footer], sqlContext: SQLContext): StructType = { + footers.map { footer => + val metadata = footer.getParquetMetadata.getFileMetaData + val parquetSchema = metadata.getSchema + val maybeSparkSchema = metadata + .getKeyValueMetaData + .toMap + .get(RowReadSupport.SPARK_METADATA_KEY) + .map(DataType.fromJson(_).asInstanceOf[StructType]) + + maybeSparkSchema.getOrElse { + // Falls back to Parquet schema if Spark SQL schema is absent. + StructType.fromAttributes( + // TODO Really no need to use `Attribute` here, we only need to know the data type. + convertToAttributes( + parquetSchema, + sqlContext.conf.isParquetBinaryAsString, + sqlContext.conf.isParquetINT96AsTimestamp)) + } + }.reduce { (left, right) => + try left.merge(right) catch { case e: Throwable => + throw new SparkException(s"Failed to merge incompatible schemas $left and $right", e) + } + } + } + + private[parquet] def mergeMetastoreParquetSchema( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + def schemaConflictMessage = + s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema: + |${metastoreSchema.prettyJson} + | + |Parquet schema: + |${parquetSchema.prettyJson} + """.stripMargin + + assert(metastoreSchema.size == parquetSchema.size, schemaConflictMessage) + + val ordinalMap = metastoreSchema.zipWithIndex.map { + case (field, index) => field.name.toLowerCase -> index + }.toMap + val reorderedParquetSchema = parquetSchema.sortBy(f => ordinalMap(f.name.toLowerCase)) + + StructType(metastoreSchema.zip(reorderedParquetSchema).map { + // Uses Parquet field names but retains Metastore data types. + case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase => + mSchema.copy(name = pSchema.name) + case _ => + throw new SparkException(schemaConflictMessage) + }) + } + + // TODO Data source implementations shouldn't touch Catalyst types (`Literal`). + // However, we are already using Catalyst expressions for partition pruning and predicate + // push-down here... + private[parquet] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) { + require(columnNames.size == literals.size) + } + + /** + * Given a group of qualified paths, tries to parse them and returns a partition specification. 
+ * For example, given: + * {{{ + * hdfs://:/path/to/partition/a=1/b=hello/c=3.14 + * hdfs://:/path/to/partition/a=2/b=world/c=6.28 + * }}} + * it returns: + * {{{ + * PartitionSpec( + * partitionColumns = StructType( + * StructField(name = "a", dataType = IntegerType, nullable = true), + * StructField(name = "b", dataType = StringType, nullable = true), + * StructField(name = "c", dataType = DoubleType, nullable = true)), + * partitions = Seq( + * Partition( + * values = Row(1, "hello", 3.14), + * path = "hdfs://:/path/to/partition/a=1/b=hello/c=3.14"), + * Partition( + * values = Row(2, "world", 6.28), + * path = "hdfs://:/path/to/partition/a=2/b=world/c=6.28"))) + * }}} + */ + private[parquet] def parsePartitions( + paths: Seq[Path], + defaultPartitionName: String): PartitionSpec = { + val partitionValues = resolvePartitions(paths.map(parsePartition(_, defaultPartitionName))) + val fields = { + val (PartitionValues(columnNames, literals)) = partitionValues.head + columnNames.zip(literals).map { case (name, Literal(_, dataType)) => + StructField(name, dataType, nullable = true) + } + } + + val partitions = partitionValues.zip(paths).map { + case (PartitionValues(_, literals), path) => + Partition(Row(literals.map(_.value): _*), path.toString) + } + + PartitionSpec(StructType(fields), partitions) + } + + /** + * Parses a single partition, returns column names and values of each partition column. For + * example, given: + * {{{ + * path = hdfs://:/path/to/partition/a=42/b=hello/c=3.14 + * }}} + * it returns: + * {{{ + * PartitionValues( + * Seq("a", "b", "c"), + * Seq( + * Literal(42, IntegerType), + * Literal("hello", StringType), + * Literal(3.14, FloatType))) + * }}} + */ + private[parquet] def parsePartition( + path: Path, + defaultPartitionName: String): PartitionValues = { + val columns = ArrayBuffer.empty[(String, Literal)] + // Old Hadoop versions don't have `Path.isRoot` + var finished = path.getParent == null + var chopped = path + + while (!finished) { + val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName) + maybeColumn.foreach(columns += _) + chopped = chopped.getParent + finished = maybeColumn.isEmpty || chopped.getParent == null + } + + val (columnNames, values) = columns.reverse.unzip + PartitionValues(columnNames, values) + } + + private def parsePartitionColumn( + columnSpec: String, + defaultPartitionName: String): Option[(String, Literal)] = { + val equalSignIndex = columnSpec.indexOf('=') + if (equalSignIndex == -1) { + None + } else { + val columnName = columnSpec.take(equalSignIndex) + assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'") + + val rawColumnValue = columnSpec.drop(equalSignIndex + 1) + assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") + + val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName) + Some(columnName -> literal) + } + } + + /** + * Resolves possible type conflicts between partitions by up-casting "lower" types. 
The up- + * casting order is: + * {{{ + * NullType -> + * IntegerType -> LongType -> + * FloatType -> DoubleType -> DecimalType.Unlimited -> + * StringType + * }}} + */ + private[parquet] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = { + val distinctColNamesOfPartitions = values.map(_.columnNames).distinct + val columnCount = values.head.columnNames.size + + // Column names of all partitions must match + assert(distinctColNamesOfPartitions.size == 1, { + val list = distinctColNamesOfPartitions.mkString("\t", "\n", "") + s"Conflicting partition column names detected:\n$list" + }) + + // Resolves possible type conflicts for each column + val resolvedValues = (0 until columnCount).map { i => + resolveTypeConflicts(values.map(_.literals(i))) + } + + // Fills resolved literals back to each partition + values.zipWithIndex.map { case (d, index) => + d.copy(literals = resolvedValues.map(_(index))) + } + } + + /** + * Converts a string to a `Literal` with automatic type inference. Currently only supports + * [[IntegerType]], [[LongType]], [[FloatType]], [[DoubleType]], [[DecimalType.Unlimited]], and + * [[StringType]]. + */ + private[parquet] def inferPartitionColumnValue( + raw: String, + defaultPartitionName: String): Literal = { + // First tries integral types + Try(Literal(Integer.parseInt(raw), IntegerType)) + .orElse(Try(Literal(JLong.parseLong(raw), LongType))) + // Then falls back to fractional types + .orElse(Try(Literal(JFloat.parseFloat(raw), FloatType))) + .orElse(Try(Literal(JDouble.parseDouble(raw), DoubleType))) + .orElse(Try(Literal(new JBigDecimal(raw), DecimalType.Unlimited))) + // Then falls back to string + .getOrElse { + if (raw == defaultPartitionName) Literal(null, NullType) else Literal(raw, StringType) + } + } + + private val upCastingOrder: Seq[DataType] = + Seq(NullType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited, StringType) + + /** + * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" + * types. 
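+   * For example:
+   * {{{
+   *   resolveTypeConflicts(Seq(Literal(1, IntegerType), Literal(2.5, DoubleType)))
+   *   // => Seq(Literal(1.0, DoubleType), Literal(2.5, DoubleType))
+   * }}}
+   */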
+ */ + private def resolveTypeConflicts(literals: Seq[Literal]): Seq[Literal] = { + val desiredType = { + val topType = literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_)) + // Falls back to string if all values of this column are null or empty string + if (topType == NullType) StringType else topType + } + + literals.map { case l @ Literal(_, dataType) => + Literal(Cast(l, desiredType).eval(), desiredType) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 386ff2452f1a3..d23ffb8b7a960 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, Strategy} import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, InsertIntoTable => LogicalInsertIntoTable} -import org.apache.spark.sql.execution +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.{Row, Strategy, execution} /** * A Strategy for planning scans over data sources defined using the sources API. @@ -54,7 +54,7 @@ private[sql] object DataSourceStrategy extends Strategy { case l @ LogicalRelation(t: TableScan) => execution.PhysicalRDD(l.output, t.buildScan()) :: Nil - case i @ LogicalInsertIntoTable( + case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) => if (partition.nonEmpty) { sys.error(s"Insert into a partition is not allowed because $l is not partitioned.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index ead827728cf4b..9c37e0169ff85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -23,6 +23,8 @@ import org.apache.spark.Logging import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -50,7 +52,6 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { } } - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object protected val CREATE = Keyword("CREATE") @@ -61,6 +62,8 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { protected val EXISTS = Keyword("EXISTS") protected val USING = Keyword("USING") protected val OPTIONS = Keyword("OPTIONS") + protected val DESCRIBE = Keyword("DESCRIBE") + protected val EXTENDED = Keyword("EXTENDED") protected val AS = Keyword("AS") protected val COMMENT = Keyword("COMMENT") @@ -82,7 +85,7 @@ private[sql] 
class DDLParser extends AbstractSparkSQLParser with Logging { protected val MAP = Keyword("MAP") protected val STRUCT = Keyword("STRUCT") - protected lazy val ddl: Parser[LogicalPlan] = createTable + protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable protected def start: Parser[LogicalPlan] = ddl @@ -136,6 +139,22 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" + /* + * describe [extended] table avroTable + * This will display all columns of table `avroTable` includes column_name,column_type,nullable + */ + protected lazy val describeTable: Parser[LogicalPlan] = + (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { + case e ~ db ~ tbl => + val tblIdentifier = db match { + case Some(dbName) => + Seq(dbName, tbl) + case None => + Seq(tbl) + } + DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) + } + protected lazy val options: Parser[Map[String, String]] = "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } @@ -222,20 +241,16 @@ object ResolvedDataSource { val relation = userSpecifiedSchema match { case Some(schema: StructType) => { clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider] - .createRelation(sqlContext, new CaseInsensitiveMap(options), schema) + case dataSource: SchemaRelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) case dataSource: org.apache.spark.sql.sources.RelationProvider => sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.") } } case None => { clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.RelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.RelationProvider] - .createRelation(sqlContext, new CaseInsensitiveMap(options)) + case dataSource: RelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.") } @@ -260,10 +275,8 @@ object ResolvedDataSource { } val relation = clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.CreateableRelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.CreateableRelationProvider] - .createRelation(sqlContext, options, data) + case dataSource: CreatableRelationProvider => + dataSource.createRelation(sqlContext, options, data) case _ => sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") } @@ -274,6 +287,22 @@ object ResolvedDataSource { private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) +/** + * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. + * @param table The table to be described. + * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. + * It is effective only when the table is a Hive table. + */ +private[sql] case class DescribeCommand( + table: LogicalPlan, + isExtended: Boolean) extends Command { + override def output = Seq( + // Column names are based on Hive. 
+ AttributeReference("col_name", StringType, nullable = false)(), + AttributeReference("data_type", StringType, nullable = false)(), + AttributeReference("comment", StringType, nullable = false)()) +} + private[sql] case class CreateTableUsing( tableName: String, userSpecifiedSchema: Option[StructType], @@ -331,7 +360,7 @@ private [sql] case class CreateTempTableUsingAsSelect( /** * Builds a map in which keys are case insensitive */ -protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] +protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] with Serializable { val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index ad0a35b91ebc2..40fc1f2aa2724 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -78,7 +78,7 @@ trait SchemaRelationProvider { } @DeveloperApi -trait CreateableRelationProvider { +trait CreatableRelationProvider { def createRelation( sqlContext: SQLContext, parameters: Map[String, String], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 74c29459d2e47..77fd3165f151f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -17,19 +17,23 @@ package org.apache.spark.sql +import scala.language.postfixOps + import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.types._ - -/* Implicits */ +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery import org.apache.spark.sql.test.TestSQLContext.implicits._ -import scala.language.postfixOps class DataFrameSuite extends QueryTest { import org.apache.spark.sql.TestData._ test("analysis error should be eagerly reported") { + val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis + // Eager analysis. + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + intercept[Exception] { testData.select('nonExistentName) } intercept[Exception] { testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) @@ -40,6 +44,13 @@ class DataFrameSuite extends QueryTest { intercept[Exception] { testData.groupBy($"abcd").agg(Map("key" -> "sum")) } + + // No more eager analysis once the flag is turned off + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + testData.select('nonExistentName) + + // Set the flag back to original value before this test. 
+ TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } test("table scan") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index a7f2faa3ecf75..f9ddd2ca5c567 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,12 +17,19 @@ package org.apache.spark.sql +import java.util.{Locale, TimeZone} + import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation class QueryTest extends PlanTest { + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index dc8ee41712fcd..11502edf972e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import java.util.TimeZone - import org.apache.spark.sql.test.TestSQLContext import org.scalatest.BeforeAndAfterAll @@ -37,16 +35,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { import org.apache.spark.sql.test.TestSQLContext.implicits._ - var origZone: TimeZone = _ - override protected def beforeAll() { - origZone = TimeZone.getDefault - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - } - - override protected def afterAll() { - TimeZone.setDefault(origZone) - } - test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( sql("SELECT a FROM testData2 SORT BY a"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index ff91a0eb42049..f8117c21773ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -22,8 +22,10 @@ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Predicate, Row} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} /** @@ -54,9 +56,17 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { .select(output.map(e => Column(e)): _*) .where(Column(predicate)) - val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect { - case plan: ParquetTableScan => plan.columnPruningPred - }.flatten.reduceOption(_ && _) + val maybeAnalyzedPredicate = { + val forParquetTableScan = query.queryExecution.executedPlan.collect { + case plan: ParquetTableScan => plan.columnPruningPred + }.flatten.reduceOption(_ && _) + + val forParquetDataSource = query.queryExecution.optimizedPlan.collect { + case 
PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation2)) => filters + }.flatten.reduceOption(_ && _) + + forParquetTableScan.orElse(forParquetDataSource) + } assert(maybeAnalyzedPredicate.isDefined) maybeAnalyzedPredicate.foreach { pred => @@ -84,213 +94,228 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } - test("filter pushdown - boolean") { - withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) - - checkFilterPredicate('_1 === true, classOf[Eq [_]], true) - checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) + private def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) + (implicit rdd: DataFrame): Unit = { + def checkBinaryAnswer(rdd: DataFrame, expected: Seq[Row]) = { + assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { + rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted + } } + + checkFilterPredicate(rdd, predicate, filterClass, checkBinaryAnswer _, expected) } - test("filter pushdown - short") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd => - checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq [_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt [_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt [_]], 4) - checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) - - checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) - checkFilterPredicate(Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, - classOf[Operators.And], 3) - checkFilterPredicate(Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, - classOf[Operators.Or], Seq(Row(1), Row(4))) - } + private def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) + (implicit rdd: DataFrame): Unit = { + checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } - test("filter pushdown - integer") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + def run(prefix: String): Unit = { + test(s"$prefix: filter pushdown - boolean") { + withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) + + checkFilterPredicate('_1 === true, classOf[Eq[_]], true) + checkFilterPredicate('_1 !== true, 
classOf[NotEq[_]], false) + } + } + + test(s"$prefix: filter pushdown - short") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd => + checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1) + checkFilterPredicate( + Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4) + checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) + + checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) + checkFilterPredicate( + Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3) + checkFilterPredicate( + Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, + classOf[Operators.Or], + Seq(Row(1), Row(4))) + } + } + + test(s"$prefix: filter pushdown - integer") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } } - } - test("filter pushdown - long") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], 
Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test(s"$prefix: filter pushdown - long") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } } - } - test("filter pushdown - float") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test(s"$prefix: filter pushdown - float") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - 
checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } } - } - test("filter pushdown - double") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test(s"$prefix: filter pushdown - double") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } } - } - test("filter pushdown - string") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") - checkFilterPredicate('_1 !== "1", 
classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") - checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") - checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") - checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") - - checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") - checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") - checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") - checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") - checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") - - checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") - checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") - checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + test(s"$prefix: filter pushdown - string") { + withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") + checkFilterPredicate( + '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") + checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") + checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") + + checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") + checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") + checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") + checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") + + checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") + checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") + checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + } } - } - def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: DataFrame): Unit = { - def checkBinaryAnswer(rdd: DataFrame, expected: Seq[Row]) = { - assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { - rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted + test(s"$prefix: filter pushdown - binary") { + implicit class IntToBinary(int: Int) { + def b: Array[Byte] = int.toString.getBytes("UTF-8") } - } - checkFilterPredicate(rdd, predicate, filterClass, checkBinaryAnswer _, expected) - } + withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => + checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) - def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) - (implicit rdd: DataFrame): Unit = { - checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) - } + checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkBinaryFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) - test("filter pushdown - binary") { - implicit class IntToBinary(int: Int) { - def b: Array[Byte] = int.toString.getBytes("UTF-8") - } + checkBinaryFilterPredicate( + '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) + + checkBinaryFilterPredicate('_1 < 2.b, 
classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => - checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkBinaryFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) - - checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq [_]], 1.b) - checkBinaryFilterPredicate( - '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) - - checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - - checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) - checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) - - checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) - checkBinaryFilterPredicate( - '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) + checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) + + checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) + checkBinaryFilterPredicate( + '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) + } } } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + run("Parquet data source enabled") + } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { + run("Parquet data source disabled") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 0bc246c645602..c8ebbbc7d2eac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -73,218 +73,229 @@ class ParquetIOSuite extends QueryTest with ParquetTest { withParquetRDD(data)(r => checkAnswer(r, data.map(Row.fromTuple))) } - test("basic data types (without binary)") { - val data = (1 to 4).map { i => - (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + def run(prefix: String): Unit = { + test(s"$prefix: basic data types (without binary)") { + val data = (1 to 4).map { i => + (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + } + checkParquetFile(data) } - checkParquetFile(data) - } - test("raw binary") { - val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) - withParquetRDD(data) { rdd => - assertResult(data.map(_._1.mkString(",")).sorted) { - rdd.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted + test(s"$prefix: raw binary") { + val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) + 
withParquetRDD(data) { rdd => + assertResult(data.map(_._1.mkString(",")).sorted) { + rdd.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted + } } } - } - - test("string") { - val data = (1 to 4).map(i => Tuple1(i.toString)) - // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL - // as we store Spark SQL schema in the extra metadata. - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data)) - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data)) - } - test("fixed-length decimals") { - import org.apache.spark.sql.test.TestSQLContext.implicits._ - - def makeDecimalRDD(decimal: DecimalType): DataFrame = - sparkContext - .parallelize(0 to 1000) - .map(i => Tuple1(i / 100.0)) - .select($"_1" cast decimal as "abcd") + test(s"$prefix: string") { + val data = (1 to 4).map(i => Tuple1(i.toString)) + // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL + // as we store Spark SQL schema in the extra metadata. + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data)) + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data)) + } - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { - withTempPath { dir => - val data = makeDecimalRDD(DecimalType(precision, scale)) - data.saveAsParquetFile(dir.getCanonicalPath) - checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + test(s"$prefix: fixed-length decimals") { + import org.apache.spark.sql.test.TestSQLContext.implicits._ + + def makeDecimalRDD(decimal: DecimalType): DataFrame = + sparkContext + .parallelize(0 to 1000) + .map(i => Tuple1(i / 100.0)) + // Parquet doesn't allow column names with spaces, have to add an alias here + .select($"_1" cast decimal as "dec") + + for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { + withTempPath { dir => + val data = makeDecimalRDD(DecimalType(precision, scale)) + data.saveAsParquetFile(dir.getCanonicalPath) + checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + } } - } - // Decimals with precision above 18 are not yet supported - intercept[RuntimeException] { - withTempPath { dir => - makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() + // Decimals with precision above 18 are not yet supported + intercept[RuntimeException] { + withTempPath { dir => + makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath) + parquetFile(dir.getCanonicalPath).collect() + } } - } - // Unlimited-length decimals are not yet supported - intercept[RuntimeException] { - withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() + // Unlimited-length decimals are not yet supported + intercept[RuntimeException] { + withTempPath { dir => + makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath) + parquetFile(dir.getCanonicalPath).collect() + } } } - } - test("map") { - val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) - checkParquetFile(data) - } + test(s"$prefix: map") { + val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) + checkParquetFile(data) + } - test("array") { - val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1))) - checkParquetFile(data) - } + test(s"$prefix: array") { + val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1))) + 
checkParquetFile(data) + } - test("struct") { - val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) - withParquetRDD(data) { rdd => - // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => - Row(Row(struct.productIterator.toSeq: _*)) - }) + test(s"$prefix: struct") { + val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) + withParquetRDD(data) { rdd => + // Structs are converted to `Row`s + checkAnswer(rdd, data.map { case Tuple1(struct) => + Row(Row(struct.productIterator.toSeq: _*)) + }) + } } - } - test("nested struct with array of array as field") { - val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) - withParquetRDD(data) { rdd => - // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => - Row(Row(struct.productIterator.toSeq: _*)) - }) + test(s"$prefix: nested struct with array of array as field") { + val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) + withParquetRDD(data) { rdd => + // Structs are converted to `Row`s + checkAnswer(rdd, data.map { case Tuple1(struct) => + Row(Row(struct.productIterator.toSeq: _*)) + }) + } } - } - test("nested map with struct as value type") { - val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) - withParquetRDD(data) { rdd => - checkAnswer(rdd, data.map { case Tuple1(m) => - Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) - }) + test(s"$prefix: nested map with struct as value type") { + val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) + withParquetRDD(data) { rdd => + checkAnswer(rdd, data.map { case Tuple1(m) => + Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) + }) + } } - } - test("nulls") { - val allNulls = ( - null.asInstanceOf[java.lang.Boolean], - null.asInstanceOf[Integer], - null.asInstanceOf[java.lang.Long], - null.asInstanceOf[java.lang.Float], - null.asInstanceOf[java.lang.Double]) - - withParquetRDD(allNulls :: Nil) { rdd => - val rows = rdd.collect() - assert(rows.size === 1) - assert(rows.head === Row(Seq.fill(5)(null): _*)) + test(s"$prefix: nulls") { + val allNulls = ( + null.asInstanceOf[java.lang.Boolean], + null.asInstanceOf[Integer], + null.asInstanceOf[java.lang.Long], + null.asInstanceOf[java.lang.Float], + null.asInstanceOf[java.lang.Double]) + + withParquetRDD(allNulls :: Nil) { rdd => + val rows = rdd.collect() + assert(rows.size === 1) + assert(rows.head === Row(Seq.fill(5)(null): _*)) + } } - } - test("nones") { - val allNones = ( - None.asInstanceOf[Option[Int]], - None.asInstanceOf[Option[Long]], - None.asInstanceOf[Option[String]]) + test(s"$prefix: nones") { + val allNones = ( + None.asInstanceOf[Option[Int]], + None.asInstanceOf[Option[Long]], + None.asInstanceOf[Option[String]]) - withParquetRDD(allNones :: Nil) { rdd => - val rows = rdd.collect() - assert(rows.size === 1) - assert(rows.head === Row(Seq.fill(3)(null): _*)) + withParquetRDD(allNones :: Nil) { rdd => + val rows = rdd.collect() + assert(rows.size === 1) + assert(rows.head === Row(Seq.fill(3)(null): _*)) + } } - } - test("compression codec") { - def compressionCodecFor(path: String) = { - val codecs = ParquetTypesConverter - .readMetaData(new Path(path), Some(configuration)) - .getBlocks - .flatMap(_.getColumns) - .map(_.getCodec.name()) - .distinct - - assert(codecs.size === 1) - codecs.head - } + test(s"$prefix: compression codec") { + def compressionCodecFor(path: String) = { + val codecs = ParquetTypesConverter + .readMetaData(new Path(path), Some(configuration)) + .getBlocks + 
.flatMap(_.getColumns) + .map(_.getCodec.name()) + .distinct + + assert(codecs.size === 1) + codecs.head + } - val data = (0 until 10).map(i => (i, i.toString)) + val data = (0 until 10).map(i => (i, i.toString)) - def checkCompressionCodec(codec: CompressionCodecName): Unit = { - withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { - withParquetFile(data) { path => - assertResult(conf.parquetCompressionCodec.toUpperCase) { - compressionCodecFor(path) + def checkCompressionCodec(codec: CompressionCodecName): Unit = { + withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { + withParquetFile(data) { path => + assertResult(conf.parquetCompressionCodec.toUpperCase) { + compressionCodecFor(path) + } } } } - } - // Checks default compression codec - checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) + // Checks default compression codec + checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) - checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) - checkCompressionCodec(CompressionCodecName.GZIP) - checkCompressionCodec(CompressionCodecName.SNAPPY) - } + checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) + checkCompressionCodec(CompressionCodecName.GZIP) + checkCompressionCodec(CompressionCodecName.SNAPPY) + } - test("read raw Parquet file") { - def makeRawParquetFile(path: Path): Unit = { - val schema = MessageTypeParser.parseMessageType( - """ - |message root { - | required boolean _1; - | required int32 _2; - | required int64 _3; - | required float _4; - | required double _5; - |} - """.stripMargin) - - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - - (0 until 10).foreach { i => - val record = new SimpleGroup(schema) - record.add(0, i % 2 == 0) - record.add(1, i) - record.add(2, i.toLong) - record.add(3, i.toFloat) - record.add(4, i.toDouble) - writer.write(record) - } + test(s"$prefix: read raw Parquet file") { + def makeRawParquetFile(path: Path): Unit = { + val schema = MessageTypeParser.parseMessageType( + """ + |message root { + | required boolean _1; + | required int32 _2; + | required int64 _3; + | required float _4; + | required double _5; + |} + """.stripMargin) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + + (0 until 10).foreach { i => + val record = new SimpleGroup(schema) + record.add(0, i % 2 == 0) + record.add(1, i) + record.add(2, i.toLong) + record.add(3, i.toFloat) + record.add(4, i.toDouble) + writer.write(record) + } - writer.close() - } + writer.close() + } - withTempDir { dir => - val path = new Path(dir.toURI.toString, "part-r-0.parquet") - makeRawParquetFile(path) - checkAnswer(parquetFile(path.toString), (0 until 10).map { i => - Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) - }) + withTempDir { dir => + val path = new Path(dir.toURI.toString, "part-r-0.parquet") + makeRawParquetFile(path) + checkAnswer(parquetFile(path.toString), (0 until 10).map { i => + Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + }) + } } - } - test("write metadata") { - withTempPath { file => - val path = new Path(file.toURI.toString) - val fs = FileSystem.getLocal(configuration) - val attributes = ScalaReflection.attributesFor[(Int, String)] - ParquetTypesConverter.writeMetaData(attributes, path, configuration) + test(s"$prefix: write metadata") { + withTempPath { file => + val path = new Path(file.toURI.toString) + val fs = FileSystem.getLocal(configuration) + 
val attributes = ScalaReflection.attributesFor[(Int, String)] + ParquetTypesConverter.writeMetaData(attributes, path, configuration) - assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) - assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) + assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) + assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) - val actualSchema = metaData.getFileMetaData.getSchema - val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) + val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) + val actualSchema = metaData.getFileMetaData.getSchema + val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) - actualSchema.checkContains(expectedSchema) - expectedSchema.checkContains(actualSchema) + actualSchema.checkContains(expectedSchema) + expectedSchema.checkContains(actualSchema) + } } } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + run("Parquet data source enabled") + } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { + run("Parquet data source disabled") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala new file mode 100644 index 0000000000000..ae606d11a8f68 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.parquet + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.fs.Path +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.parquet.ParquetRelation2._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Row, SQLContext} + +class ParquetPartitionDiscoverySuite extends FunSuite with ParquetTest { + override val sqlContext: SQLContext = TestSQLContext + + val defaultPartitionName = "__NULL__" + + test("column type inference") { + def check(raw: String, literal: Literal): Unit = { + assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) + } + + check("10", Literal(10, IntegerType)) + check("1000000000000000", Literal(1000000000000000L, LongType)) + check("1.5", Literal(1.5, FloatType)) + check("hello", Literal("hello", StringType)) + check(defaultPartitionName, Literal(null, NullType)) + } + + test("parse partition") { + def check(path: String, expected: PartitionValues): Unit = { + assert(expected === parsePartition(new Path(path), defaultPartitionName)) + } + + def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { + val message = intercept[T] { + parsePartition(new Path(path), defaultPartitionName) + }.getMessage + + assert(message.contains(expected)) + } + + check( + "file:///", + PartitionValues( + ArrayBuffer.empty[String], + ArrayBuffer.empty[Literal])) + + check( + "file://path/a=10", + PartitionValues( + ArrayBuffer("a"), + ArrayBuffer(Literal(10, IntegerType)))) + + check( + "file://path/a=10/b=hello/c=1.5", + PartitionValues( + ArrayBuffer("a", "b", "c"), + ArrayBuffer( + Literal(10, IntegerType), + Literal("hello", StringType), + Literal(1.5, FloatType)))) + + check( + "file://path/a=10/b_hello/c=1.5", + PartitionValues( + ArrayBuffer("c"), + ArrayBuffer(Literal(1.5, FloatType)))) + + checkThrows[AssertionError]("file://path/=10", "Empty partition column name") + checkThrows[AssertionError]("file://path/a=", "Empty partition column value") + } + + test("parse partitions") { + def check(paths: Seq[String], spec: PartitionSpec): Unit = { + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName) === spec) + } + + check(Seq( + "hdfs://host:9000/path/a=10/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType))), + Seq(Partition(Row(10, "hello"), "hdfs://host:9000/path/a=10/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", FloatType), + StructField("b", StringType))), + Seq( + Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), + Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), + PartitionSpec( + StructType(Seq( + StructField("a", FloatType), + StructField("b", StringType))), + Seq( + Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(Row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 5ec7a156d9353..48c7598343e55 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.parquet -import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.{QueryTest, SQLConf} /** * A test suite that tests various Parquet queries. @@ -28,82 +28,93 @@ import org.apache.spark.sql.test.TestSQLContext._ class ParquetQuerySuite extends QueryTest with ParquetTest { val sqlContext = TestSQLContext - test("simple projection") { - withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) + def run(prefix: String): Unit = { + test(s"$prefix: simple projection") { + withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) + } } - } - test("appending") { - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - sql("INSERT INTO TABLE t SELECT * FROM t") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + // TODO Re-enable this after data source insertion API is merged + test(s"$prefix: appending") { + val data = (0 until 10).map(i => (i, i.toString)) + withParquetTable(data, "t") { + sql("INSERT INTO TABLE t SELECT * FROM t") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + } } - } - // This test case will trigger the NPE mentioned in - // https://issues.apache.org/jira/browse/PARQUET-151. - ignore("overwriting") { - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - sql("INSERT OVERWRITE TABLE t SELECT * FROM t") - checkAnswer(table("t"), data.map(Row.fromTuple)) + // This test case will trigger the NPE mentioned in + // https://issues.apache.org/jira/browse/PARQUET-151. 
+ ignore(s"$prefix: overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + withParquetTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM t") + checkAnswer(table("t"), data.map(Row.fromTuple)) + } } - } - test("self-join") { - // 4 rows, cells of column 1 of row 2 and row 4 are null - val data = (1 to 4).map { i => - val maybeInt = if (i % 2 == 0) None else Some(i) - (maybeInt, i.toString) - } + test(s"$prefix: self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) + } - withParquetTable(data, "t") { - val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") - val queryOutput = selfJoin.queryExecution.analyzed.output + withParquetTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") + val queryOutput = selfJoin.queryExecution.analyzed.output - assertResult(4, s"Field count mismatches")(queryOutput.size) - assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { - queryOutput.filter(_.name == "_1").map(_.exprId).size - } + assertResult(4, s"Field count mismatches")(queryOutput.size) + assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size + } - checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + } } - } - test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { - case Tuple1((_, Seq(string))) => Row(string) - }) + test(s"$prefix: nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) + } } - } - test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { - case Tuple1(Seq((_, string))) => Row(string) - }) + test(s"$prefix: nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) + } } - } - test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { - withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) + test(s"$prefix: SPARK-1913 regression: columns only referenced by pushed down filters should remain") { + withParquetTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) + } } - } - test("SPARK-5309 strings stored using dictionary compression in parquet") { - withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { + test(s"$prefix: SPARK-5309 strings stored using dictionary compression in parquet") { + withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { - checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), - (0 until 10).map(i => Row("same", "run_" + i, 100))) + checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + (0 until 
10).map(i => Row("same", "run_" + i, 100))) - checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), - List(Row("same", "run_5", 100))) + checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + List(Row("same", "run_5", 100))) + } } } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + run("Parquet data source enabled") + } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { + run("Parquet data source disabled") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 5f7f31d395cf7..2e6c2d5f9ab55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -25,6 +25,7 @@ import parquet.schema.MessageTypeParser import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ class ParquetSchemaSuite extends FunSuite with ParquetTest { val sqlContext = TestSQLContext @@ -192,4 +193,40 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { assert(a.nullable === b.nullable) } } + + test("merge with metastore schema") { + // Field type conflict resolution + assertResult( + StructType(Seq( + StructField("lowerCase", StringType), + StructField("UPPERCase", DoubleType, nullable = false)))) { + + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("lowercase", StringType), + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // Conflicting field count + assert(intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + + // Conflicting field names + intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq(StructField("lower", StringType))), + StructType(Seq(StructField("lowerCase", BinaryType)))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala new file mode 100644 index 0000000000000..0ec756bfeb7ef --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -0,0 +1,99 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ +import org.apache.spark.sql.types._ + +class DDLScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt)(sqlContext) + } +} + +case class SimpleDDLScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends TableScan { + + override def schema = + StructType(Seq( + StructField("intType", IntegerType, nullable = false, + new MetadataBuilder().putString("comment", "test comment").build()), + StructField("stringType", StringType, nullable = false), + StructField("dateType", DateType, nullable = false), + StructField("timestampType", TimestampType, nullable = false), + StructField("doubleType", DoubleType, nullable = false), + StructField("bigintType", LongType, nullable = false), + StructField("tinyintType", ByteType, nullable = false), + StructField("decimalType", DecimalType.Unlimited, nullable = false), + StructField("fixedDecimalType", DecimalType(5,1), nullable = false), + StructField("binaryType", BinaryType, nullable = false), + StructField("booleanType", BooleanType, nullable = false), + StructField("smallIntType", ShortType, nullable = false), + StructField("floatType", FloatType, nullable = false), + StructField("mapType", MapType(StringType, StringType)), + StructField("arrayType", ArrayType(StringType)), + StructField("structType", + StructType(StructField("f1",StringType) :: + (StructField("f2",IntegerType)) :: Nil + ) + ) + )) + + + override def buildScan() = sqlContext.sparkContext.parallelize(from to to). + map(e => Row(s"people$e", e * 2)) +} + +class DDLTestSuite extends DataSourceTest { + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE ddlPeople + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + } + + sqlTest( + "describe ddlPeople", + Seq( + Row("intType", "int", "test comment"), + Row("stringType", "string", ""), + Row("dateType", "date", ""), + Row("timestampType", "timestamp", ""), + Row("doubleType", "double", ""), + Row("bigintType", "bigint", ""), + Row("tinyintType", "tinyint", ""), + Row("decimalType", "decimal(10,0)", ""), + Row("fixedDecimalType", "decimal(5,1)", ""), + Row("binaryType", "binary", ""), + Row("booleanType", "boolean", ""), + Row("smallIntType", "smallint", ""), + Row("floatType", "float", ""), + Row("mapType", "map", ""), + Row("arrayType", "array", ""), + Row("structType", "struct", "") + )) +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 1921bf6e5e1a6..d2371d4a5519e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -75,7 +75,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { DataFrame(this, ddlParser(sqlText, exceptionOnError = false).getOrElse(HiveQl.parseSql(substituted))) } else { - sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") + sys.error(s"Unsupported SQL dialect: ${conf.dialect}. 
Try 'sql' or 'hiveql'") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 243310686d08a..c78369d12cf55 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.{DDLParser, LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -175,10 +176,25 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Nil } - // Since HiveQL is case insensitive for table names we make them all lowercase. - MetastoreRelation( + val relation = MetastoreRelation( databaseName, tblName, alias)( table.getTTable, partitions.map(part => part.getTPartition))(hive) + + if (hive.convertMetastoreParquet && + hive.conf.parquetUseDataSourceApi && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet")) { + val metastoreSchema = StructType.fromAttributes(relation.output) + val paths = if (relation.hiveQlTable.isPartitioned) { + relation.hiveQlPartitions.map(p => p.getLocation) + } else { + Seq(relation.hiveQlTable.getDataLocation.toString) + } + + LogicalRelation(ParquetRelation2( + paths, Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json))(hive)) + } else { + relation + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 62e9d92eac076..c19a091719190 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.ExplainCommand +import org.apache.spark.sql.sources.DescribeCommand import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ @@ -47,22 +48,6 @@ import scala.collection.JavaConversions._ */ private[hive] case object NativePlaceholder extends Command -/** - * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. - * @param table The table to be described. - * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. - * It is effective only when the table is a Hive table. - */ -case class DescribeCommand( - table: LogicalPlan, - isExtended: Boolean) extends Command { - override def output = Seq( - // Column names are based on Hive. - AttributeReference("col_name", StringType, nullable = false)(), - AttributeReference("data_type", StringType, nullable = false)(), - AttributeReference("comment", StringType, nullable = false)()) -} - /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. 
*/ private[hive] object HiveQl { protected val nativeCommands = Seq( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index d89111094b9ff..95abc363ae767 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.sources.DescribeCommand import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.execution._ @@ -86,7 +87,8 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) if relation.tableDesc.getSerdeClassName.contains("Parquet") && - hiveContext.convertMetastoreParquet => + hiveContext.convertMetastoreParquet && + !hiveContext.conf.parquetUseDataSourceApi => // Filter out all predicates that only deal with partition keys val partitionsKeys = AttributeSet(relation.partitionKeys) @@ -135,8 +137,10 @@ private[hive] trait HiveStrategies { pruningCondition(inputData) } + val partitionLocations = partitions.map(_.getLocation) + hiveContext - .parquetFile(partitions.map(_.getLocation).mkString(",")) + .parquetFile(partitionLocations.head, partitionLocations.tail: _*) .addPartitioningAttributes(relation.partitionKeys) .lowerCase .where(unresolvedOtherPredicates) @@ -240,8 +244,11 @@ private[hive] trait HiveStrategies { case t: MetastoreRelation => ExecutedCommand( DescribeHiveTableCommand(t, describe.output, describe.isExtended)) :: Nil + case o: LogicalPlan => - ExecutedCommand(RunnableDescribeCommand(planLater(o), describe.output)) :: Nil + val resultPlan = context.executePlan(o).executedPlan + ExecutedCommand(RunnableDescribeCommand( + resultPlan, describe.output, describe.isExtended)) :: Nil } case _ => Nil diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index f8a957d55d57e..a90bd1e257ade 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -22,8 +22,8 @@ import java.io._ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} import org.apache.spark.Logging +import org.apache.spark.sql.sources.DescribeCommand import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} -import org.apache.spark.sql.hive.DescribeCommand import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 407d6058c33ed..bb73ff1ea7e43 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -741,8 +741,8 @@ class HiveQuerySuite extends 
HiveComparisonTest with BeforeAndAfter { assertResult( Array( - Row("a", "IntegerType", null), - Row("b", "StringType", null)) + Row("a", "int", ""), + Row("b", "string", "")) ) { sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 22310ffadd25e..49fe79d989259 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,13 +17,10 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.QueryTest - -import org.apache.spark.sql.Row +import org.apache.spark.sql.hive.HiveShim import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils -import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.sql.{QueryTest, Row, SQLConf} case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) @@ -109,28 +106,34 @@ class SQLQuerySuite extends QueryTest { ) if (HiveShim.version =="0.13.1") { - sql( - """CREATE TABLE ctas5 - | STORED AS parquet AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin).collect - - checkExistence(sql("DESC EXTENDED ctas5"), true, - "name:key", "type:string", "name:value", "ctas5", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", - "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", - "MANAGED_TABLE" - ) - - val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") - // use the Hive SerDe for parquet tables - sql("set spark.sql.hive.convertMetastoreParquet = false") - checkAnswer( - sql("SELECT key, value FROM ctas5 ORDER BY key, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) - sql(s"set spark.sql.hive.convertMetastoreParquet = $default") + val origUseParquetDataSource = conf.parquetUseDataSourceApi + try { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sql( + """CREATE TABLE ctas5 + | STORED AS parquet AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect() + + checkExistence(sql("DESC EXTENDED ctas5"), true, + "name:key", "type:string", "name:value", "ctas5", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + "MANAGED_TABLE" + ) + + val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") + // use the Hive SerDe for parquet tables + sql("set spark.sql.hive.convertMetastoreParquet = false") + checkAnswer( + sql("SELECT key, value FROM ctas5 ORDER BY key, value"), + sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql(s"set spark.sql.hive.convertMetastoreParquet = $default") + } finally { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 581f666399492..eae69af5864aa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -28,53 +28,55 @@ 
class HiveParquetSuite extends QueryTest with ParquetTest { import sqlContext._ - test("Case insensitive attribute names") { - withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { - val expected = (1 to 4).map(i => Row(i.toString)) - checkAnswer(sql("SELECT upper FROM cases"), expected) - checkAnswer(sql("SELECT LOWER FROM cases"), expected) + def run(prefix: String): Unit = { + test(s"$prefix: Case insensitive attribute names") { + withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { + val expected = (1 to 4).map(i => Row(i.toString)) + checkAnswer(sql("SELECT upper FROM cases"), expected) + checkAnswer(sql("SELECT LOWER FROM cases"), expected) + } } - } - test("SELECT on Parquet table") { - val data = (1 to 4).map(i => (i, s"val_$i")) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) - } - } - - test("Simple column projection + filter on Parquet table") { - withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { - checkAnswer( - sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), - Seq(Row(true, "val_2"), Row(true, "val_4"))) + test(s"$prefix: SELECT on Parquet table") { + val data = (1 to 4).map(i => (i, s"val_$i")) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) + } } - } - test("Converting Hive to Parquet Table via saveAsParquetFile") { - withTempPath { dir => - sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).registerTempTable("p") - withTempTable("p") { + test(s"$prefix: Simple column projection + filter on Parquet table") { + withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { checkAnswer( - sql("SELECT * FROM src ORDER BY key"), - sql("SELECT * from p ORDER BY key").collect().toSeq) + sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), + Seq(Row(true, "val_2"), Row(true, "val_4"))) } } - } - - test("INSERT OVERWRITE TABLE Parquet table") { - withParquetTable((1 to 4).map(i => (i, s"val_$i")), "t") { - withTempPath { file => - sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath) - parquetFile(file.getCanonicalPath).registerTempTable("p") + test(s"$prefix: Converting Hive to Parquet Table via saveAsParquetFile") { + withTempPath { dir => + sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath) + parquetFile(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { - // let's do three overwrites for good measure - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) + checkAnswer( + sql("SELECT * FROM src ORDER BY key"), + sql("SELECT * from p ORDER BY key").collect().toSeq) + } + } + } + + // TODO Re-enable this after data source insertion API is merged + ignore(s"$prefix: INSERT OVERWRITE TABLE Parquet table") { + withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { + withTempPath { file => + sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath) + parquetFile(file.getCanonicalPath).registerTempTable("p") + withTempTable("p") { + // let's do three overwrites for good measure + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) + } } } } diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala index 30441bbbdf817..a7479a5b95864 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala @@ -23,7 +23,8 @@ import java.io.File import org.apache.spark.sql.catalyst.expressions.Row import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive._ @@ -79,7 +80,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { STORED AS INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${new File(partitionedTableDir, "p=1").getCanonicalPath}' + location '${new File(normalTableDir, "normal").getCanonicalPath}' """) (1 to 10).foreach { p => @@ -97,7 +98,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { setConf("spark.sql.hive.convertMetastoreParquet", "false") } - test("conversion is working") { + test(s"conversion is working") { assert( sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { case _: HiveTableScan => true @@ -105,6 +106,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { assert( sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { case _: ParquetTableScan => true + case _: PhysicalRDD => true }.nonEmpty) } } @@ -147,6 +149,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { */ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll { var partitionedTableDir: File = null + var normalTableDir: File = null var partitionedTableDirWithKey: File = null import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -156,6 +159,10 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll partitionedTableDir.delete() partitionedTableDir.mkdir() + normalTableDir = File.createTempFile("parquettests", "sparksql") + normalTableDir.delete() + normalTableDir.mkdir() + (1 to 10).foreach { p => val partDir = new File(partitionedTableDir, s"p=$p") sparkContext.makeRDD(1 to 10) @@ -163,6 +170,11 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll .saveAsParquetFile(partDir.getCanonicalPath) } + sparkContext + .makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-1")) + .saveAsParquetFile(new File(normalTableDir, "normal").getCanonicalPath) + partitionedTableDirWithKey = File.createTempFile("parquettests", "sparksql") partitionedTableDirWithKey.delete() partitionedTableDirWithKey.mkdir() @@ -175,99 +187,107 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll } } - Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table => - test(s"ordering of the partitioning columns $table") { - checkAnswer( - sql(s"SELECT p, stringField FROM $table WHERE p = 1"), - Seq.fill(10)(Row(1, "part-1")) - ) - - checkAnswer( - sql(s"SELECT stringField, p FROM $table WHERE p = 1"), - Seq.fill(10)(Row("part-1", 1)) - ) - } - - test(s"project the partitioning column $table") { - checkAnswer( - sql(s"SELECT p, count(*) FROM $table group by p"), - Row(1, 10) :: - Row(2, 10) :: - Row(3, 10) :: - Row(4, 10) :: - Row(5, 10) :: - Row(6, 
-        Row(6, 10) ::
-        Row(7, 10) ::
-        Row(8, 10) ::
-        Row(9, 10) ::
-        Row(10, 10) :: Nil
-      )
-    }
-
-    test(s"project partitioning and non-partitioning columns $table") {
-      checkAnswer(
-        sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"),
-        Row("part-1", 1, 10) ::
-        Row("part-2", 2, 10) ::
-        Row("part-3", 3, 10) ::
-        Row("part-4", 4, 10) ::
-        Row("part-5", 5, 10) ::
-        Row("part-6", 6, 10) ::
-        Row("part-7", 7, 10) ::
-        Row("part-8", 8, 10) ::
-        Row("part-9", 9, 10) ::
-        Row("part-10", 10, 10) :: Nil
-      )
-    }
-
-    test(s"simple count $table") {
-      checkAnswer(
-        sql(s"SELECT COUNT(*) FROM $table"),
-        Row(100))
+  def run(prefix: String): Unit = {
+    Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table =>
+      test(s"$prefix: ordering of the partitioning columns $table") {
+        checkAnswer(
+          sql(s"SELECT p, stringField FROM $table WHERE p = 1"),
+          Seq.fill(10)(Row(1, "part-1"))
+        )
+
+        checkAnswer(
+          sql(s"SELECT stringField, p FROM $table WHERE p = 1"),
+          Seq.fill(10)(Row("part-1", 1))
+        )
+      }
+
+      test(s"$prefix: project the partitioning column $table") {
+        checkAnswer(
+          sql(s"SELECT p, count(*) FROM $table group by p"),
+          Row(1, 10) ::
+          Row(2, 10) ::
+          Row(3, 10) ::
+          Row(4, 10) ::
+          Row(5, 10) ::
+          Row(6, 10) ::
+          Row(7, 10) ::
+          Row(8, 10) ::
+          Row(9, 10) ::
+          Row(10, 10) :: Nil
+        )
+      }
+
+      test(s"$prefix: project partitioning and non-partitioning columns $table") {
+        checkAnswer(
+          sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"),
+          Row("part-1", 1, 10) ::
+          Row("part-2", 2, 10) ::
+          Row("part-3", 3, 10) ::
+          Row("part-4", 4, 10) ::
+          Row("part-5", 5, 10) ::
+          Row("part-6", 6, 10) ::
+          Row("part-7", 7, 10) ::
+          Row("part-8", 8, 10) ::
+          Row("part-9", 9, 10) ::
+          Row("part-10", 10, 10) :: Nil
+        )
+      }
+
+      test(s"$prefix: simple count $table") {
+        checkAnswer(
+          sql(s"SELECT COUNT(*) FROM $table"),
+          Row(100))
+      }
+
+      test(s"$prefix: pruned count $table") {
+        checkAnswer(
+          sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"),
+          Row(10))
+      }
+
+      test(s"$prefix: non-existent partition $table") {
+        checkAnswer(
+          sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"),
+          Row(0))
+      }
+
+      test(s"$prefix: multi-partition pruned count $table") {
+        checkAnswer(
+          sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"),
+          Row(30))
+      }
+
+      test(s"$prefix: non-partition predicates $table") {
+        checkAnswer(
+          sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"),
+          Row(30))
+      }
+
+      test(s"$prefix: sum $table") {
+        checkAnswer(
+          sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"),
+          Row(1 + 2 + 3))
+      }
+
+      test(s"$prefix: hive udfs $table") {
+        checkAnswer(
+          sql(s"SELECT concat(stringField, stringField) FROM $table"),
+          sql(s"SELECT stringField FROM $table").map {
+            case Row(s: String) => Row(s + s)
+          }.collect().toSeq)
+      }
     }
 
-    test(s"pruned count $table") {
+    test(s"$prefix: non-part select(*)") {
      checkAnswer(
-        sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"),
+        sql("SELECT COUNT(*) FROM normal_parquet"),
         Row(10))
     }
-
-    test(s"non-existant partition $table") {
-      checkAnswer(
-        sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"),
-        Row(0))
-    }
-
-    test(s"multi-partition pruned count $table") {
-      checkAnswer(
-        sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"),
-        Row(30))
-    }
-
-    test(s"non-partition predicates $table") {
-      checkAnswer(
-        sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"),
-        Row(30))
-    }
-
-    test(s"sum $table") {
-      checkAnswer(
-        sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"),
-        Row(1 + 2 + 3))
-    }
-
-    test(s"hive udfs $table") {
-      checkAnswer(
-        sql(s"SELECT concat(stringField, stringField) FROM $table"),
-        sql(s"SELECT stringField FROM $table").map {
-          case Row(s: String) => Row(s + s)
-        }.collect().toSeq)
-    }
   }
 
-  test("non-part select(*)") {
-    checkAnswer(
-      sql("SELECT COUNT(*) FROM normal_parquet"),
-      Row(10))
-  }
+  setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
+  run("Parquet data source enabled")
+
+  setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
+  run("Parquet data source disabled")
 }
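The `run(prefix)` indirection used in both suites above registers the same set of tests twice, once per value of `SQLConf.PARQUET_USE_DATA_SOURCE_API`, so each query is exercised against both the existing Parquet scan path and the new data source path. Below is a minimal, self-contained sketch of this toggle-and-rerun idiom; the suite name, configuration map, and key are hypothetical stand-ins and are not part of the patch.

import scala.collection.mutable
import org.scalatest.FunSuite

// Hypothetical sketch: register the shared tests once per configuration value,
// using the prefix to keep the two registrations distinct in the test report.
class ToggleAndRerunExample extends FunSuite {
  // Stand-in for the shared SQL session configuration that setConf mutates above.
  private val conf = mutable.Map.empty[String, String]
  private def setConf(key: String, value: String): Unit = conf(key) = value

  def run(prefix: String): Unit = {
    test(s"$prefix: example query") {
      // A real suite would issue SQL here; this only checks that the flag was set.
      assert(conf.get("example.useDataSourceApi").exists(v => v == "true" || v == "false"))
    }
  }

  setConf("example.useDataSourceApi", "true")
  run("data source enabled")

  setConf("example.useDataSourceApi", "false")
  run("data source disabled")
}

One caveat of the sketch: ScalaTest registers tests at construction time but runs them later, so a body that reads the flag at execution time sees only the last value written; if the two modes must observe different values while running, a per-mode subclass with a `beforeAll` override is the safer shape.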