[SPARK-18362][SQL] Use TextFileFormat in implementation of CSVFileFormat #15813

Closed
@@ -27,10 +27,12 @@ import org.apache.hadoop.mapreduce._

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.functions.{length, trim}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
@@ -52,17 +54,21 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
require(files.nonEmpty, "Cannot infer schema from an empty set of files")
val csvOptions = new CSVOptions(options)

// TODO: Move filtering.
val paths = files.filterNot(_.getPath.getName startsWith "_").map(_.getPath.toString)
val rdd = baseRdd(sparkSession, csvOptions, paths)
val firstLine = findFirstLine(csvOptions, rdd)
val lines: Dataset[String] = readText(sparkSession, csvOptions, paths)
val firstLine: String = findFirstLine(csvOptions, lines)
val firstRow = new CsvReader(csvOptions).parseLine(firstLine)
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
val header = makeSafeHeader(firstRow, csvOptions, caseSensitive)

val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths)
val parsedRdd: RDD[Array[String]] = CSVRelation.univocityTokenizer(
lines,
firstLine = if (csvOptions.headerFlag) firstLine else null,
params = csvOptions)
val schema = if (csvOptions.inferSchemaFlag) {
CSVInferSchema.infer(parsedRdd, header, csvOptions)
} else {
@@ -173,51 +179,37 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
}

private def baseRdd(
sparkSession: SparkSession,
options: CSVOptions,
inputPaths: Seq[String]): RDD[String] = {
readText(sparkSession, options, inputPaths.mkString(","))
}

private def tokenRdd(
sparkSession: SparkSession,
options: CSVOptions,
header: Array[String],
inputPaths: Seq[String]): RDD[Array[String]] = {
val rdd = baseRdd(sparkSession, options, inputPaths)
// Make sure firstLine is materialized before sending to executors
val firstLine = if (options.headerFlag) findFirstLine(options, rdd) else null
CSVRelation.univocityTokenizer(rdd, firstLine, options)
}

/**
* Returns the first line of the first non-empty file in path
*/
private def findFirstLine(options: CSVOptions, rdd: RDD[String]): String = {
private def findFirstLine(options: CSVOptions, lines: Dataset[String]): String = {
import lines.sqlContext.implicits._
val nonEmptyLines = lines.filter(length(trim($"value")) > 0)
if (options.isCommentSet) {
val comment = options.comment.toString
rdd.filter { line =>
line.trim.nonEmpty && !line.startsWith(comment)
}.first()
nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)).first()
} else {
rdd.filter { line =>
line.trim.nonEmpty
}.first()
nonEmptyLines.first()
}
}

private def readText(
sparkSession: SparkSession,
options: CSVOptions,
location: String): RDD[String] = {
inputPaths: Seq[String]): Dataset[String] = {
if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
sparkSession.sparkContext.textFile(location)
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = inputPaths,
className = classOf[TextFileFormat].getName
).resolveRelation(checkFilesExist = false))
.select("value").as[String](Encoders.STRING)
@HyukjinKwon (Member) commented on Nov 9, 2016:

Hi @JoshRosen, I just happened to look at this one and I am just curious. IIUC, the schema from sparkSession.baseRelationToDataFrame will always have only a value column, not including partitioned columns (it is empty, and inputPaths will always be leaf files).

So my question is: is that .select("value") just there to be doubly sure? Just curious.

@JoshRosen (Contributor Author) replied:

I copied this logic from the text method in DataFrameReader, so that's where the value came from.

} else {
val charset = options.charset
sparkSession.sparkContext
.hadoopFile[LongWritable, Text, TextInputFormat](location)
val rdd = sparkSession.sparkContext
@rxin (Contributor) asked:

@JoshRosen do you know why the special handling for non-utf8 encoding is needed? I would think TextFileFormat itself already supports that since it is reading it in from Hadoop Text.

@JoshRosen (Contributor Author) replied:

I'm not sure; I think this was a carryover from spark-csv.

@rxin (Contributor):

cc @falaki
Can you chime in?

A Member later followed up:

@rxin, I made a PR to address it at #29063 FYI.

.hadoopFile[LongWritable, Text, TextInputFormat](inputPaths.mkString(","))
.mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
sparkSession.createDataset(rdd)(Encoders.STRING)
}
}
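For reference on the value column discussed in the first comment thread above (an illustrative sketch, not part of this PR's diff): the text data source produces a DataFrame with a single string column named value, and selecting it with the String encoder yields the Dataset[String] shape used here, mirroring DataFrameReader.textFile. Assuming a local SparkSession and an example input path:

import org.apache.spark.sql.{Dataset, Encoders, SparkSession}

object TextValueColumnSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("text-value-sketch").getOrCreate()

    // The text source yields a DataFrame with a single string column named "value".
    val df = spark.read.format("text").load("/tmp/example.csv")
    df.printSchema() // root |-- value: string (nullable = true)

    // Narrow it to Dataset[String], as DataFrameReader.textFile does with select("value").
    val lines: Dataset[String] = df.select("value").as[String](Encoders.STRING)
    lines.show(5, truncate = false)

    spark.stop()
  }
}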

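On the non-UTF-8 question in the second comment thread above (again a sketch under assumptions, not the PR's code): Hadoop's Text type stores its bytes with the expectation that they are UTF-8, which is presumably why the fallback branch decodes the raw bytes with the requested charset itself rather than relying on toString. Assuming a local SparkContext and a hypothetical ISO-8859-1 encoded file:

import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object NonUtf8LinesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("non-utf8-sketch"))
    val charset = "ISO-8859-1" // example charset; anything other than UTF-8 takes this branch

    // Decode the raw bytes held by Text (getBytes/getLength) with the file's actual
    // charset instead of Text.toString, which assumes UTF-8.
    val lines: RDD[String] = sc
      .hadoopFile[LongWritable, Text, TextInputFormat]("/tmp/latin1.csv")
      .map { case (_, text) => new String(text.getBytes, 0, text.getLength, charset) }

    lines.take(5).foreach(println)
    sc.stop()
  }
}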
@@ -38,12 +38,12 @@ import org.apache.spark.sql.types._
object CSVRelation extends Logging {

def univocityTokenizer(
file: RDD[String],
file: Dataset[String],
firstLine: String,
params: CSVOptions): RDD[Array[String]] = {
// If header is set, make sure firstLine is materialized before sending to executors.
val commentPrefix = params.comment.toString
file.mapPartitions { iter =>
file.rdd.mapPartitions { iter =>
val parser = new CsvReader(params)
val filteredIter = iter.filter { line =>
line.trim.nonEmpty && !line.startsWith(commentPrefix)