diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index a3691158ee758..e627f040d3cc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -27,10 +27,12 @@ import org.apache.hadoop.mapreduce._
 
 import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
+import org.apache.spark.sql.functions.{length, trim}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
@@ -52,17 +54,21 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType] = {
+    require(files.nonEmpty, "Cannot infer schema from an empty set of files")
     val csvOptions = new CSVOptions(options)
 
     // TODO: Move filtering.
     val paths = files.filterNot(_.getPath.getName startsWith "_").map(_.getPath.toString)
-    val rdd = baseRdd(sparkSession, csvOptions, paths)
-    val firstLine = findFirstLine(csvOptions, rdd)
+    val lines: Dataset[String] = readText(sparkSession, csvOptions, paths)
+    val firstLine: String = findFirstLine(csvOptions, lines)
     val firstRow = new CsvReader(csvOptions).parseLine(firstLine)
     val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
     val header = makeSafeHeader(firstRow, csvOptions, caseSensitive)
 
-    val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths)
+    val parsedRdd: RDD[Array[String]] = CSVRelation.univocityTokenizer(
+      lines,
+      firstLine = if (csvOptions.headerFlag) firstLine else null,
+      params = csvOptions)
     val schema = if (csvOptions.inferSchemaFlag) {
       CSVInferSchema.infer(parsedRdd, header, csvOptions)
     } else {
@@ -173,51 +179,37 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
     }
   }
 
-  private def baseRdd(
-      sparkSession: SparkSession,
-      options: CSVOptions,
-      inputPaths: Seq[String]): RDD[String] = {
-    readText(sparkSession, options, inputPaths.mkString(","))
-  }
-
-  private def tokenRdd(
-      sparkSession: SparkSession,
-      options: CSVOptions,
-      header: Array[String],
-      inputPaths: Seq[String]): RDD[Array[String]] = {
-    val rdd = baseRdd(sparkSession, options, inputPaths)
-    // Make sure firstLine is materialized before sending to executors
-    val firstLine = if (options.headerFlag) findFirstLine(options, rdd) else null
-    CSVRelation.univocityTokenizer(rdd, firstLine, options)
-  }
-
   /**
    * Returns the first line of the first non-empty file in path
    */
-  private def findFirstLine(options: CSVOptions, rdd: RDD[String]): String = {
+  private def findFirstLine(options: CSVOptions, lines: Dataset[String]): String = {
+    import lines.sqlContext.implicits._
+    val nonEmptyLines = lines.filter(length(trim($"value")) > 0)
    if (options.isCommentSet) {
-      val comment = options.comment.toString
-      rdd.filter { line =>
-        line.trim.nonEmpty && !line.startsWith(comment)
-      }.first()
+      nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)).first()
     } else {
-      rdd.filter { line =>
-        line.trim.nonEmpty
-      }.first()
+      nonEmptyLines.first()
     }
   }
 
   private def readText(
       sparkSession: SparkSession,
       options: CSVOptions,
-      location: String): RDD[String] = {
+      inputPaths: Seq[String]): Dataset[String] = {
     if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
-      sparkSession.sparkContext.textFile(location)
+      sparkSession.baseRelationToDataFrame(
+        DataSource.apply(
+          sparkSession,
+          paths = inputPaths,
+          className = classOf[TextFileFormat].getName
+        ).resolveRelation(checkFilesExist = false))
+        .select("value").as[String](Encoders.STRING)
     } else {
       val charset = options.charset
-      sparkSession.sparkContext
-        .hadoopFile[LongWritable, Text, TextInputFormat](location)
+      val rdd = sparkSession.sparkContext
+        .hadoopFile[LongWritable, Text, TextInputFormat](inputPaths.mkString(","))
         .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
+      sparkSession.createDataset(rdd)(Encoders.STRING)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 52de11d403446..e4ce7a94be7df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -34,12 +34,12 @@ import org.apache.spark.sql.types._
 object CSVRelation extends Logging {
 
   def univocityTokenizer(
-      file: RDD[String],
+      file: Dataset[String],
       firstLine: String,
       params: CSVOptions): RDD[Array[String]] = {
     // If header is set, make sure firstLine is materialized before sending to executors.
     val commentPrefix = params.comment.toString
-    file.mapPartitions { iter =>
+    file.rdd.mapPartitions { iter =>
       val parser = new CsvReader(params)
       val filteredIter = iter.filter { line =>
         line.trim.nonEmpty && !line.startsWith(commentPrefix)
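
Not part of the patch: a minimal standalone sketch of the behaviour the diff moves CSV schema inference onto, namely reading the input as a Dataset[String] of lines and taking the header from the first non-blank (and non-comment) line. It uses the public spark.read.textFile API instead of the internal DataSource/TextFileFormat plumbing the patched readText uses; the path "/tmp/example.csv" and the "#" comment prefix are hypothetical.

import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.functions.{length, trim}

object CsvFirstLineSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("csv-first-line-sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Rough equivalent of the new readText(): load the CSV file as a
    // Dataset[String] of raw lines (column name is "value").
    val lines: Dataset[String] = spark.read.textFile("/tmp/example.csv") // hypothetical path

    // Rough equivalent of the new findFirstLine(): drop blank lines and,
    // when a comment character is configured, lines starting with it,
    // then take the first remaining line as the header.
    val comment = "#" // hypothetical comment prefix
    val firstLine = lines
      .filter(length(trim($"value")) > 0)
      .filter(!$"value".startsWith(comment))
      .first()

    println(s"header line: $firstLine")
    spark.stop()
  }
}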