[SPARK-15983][SQL] Removes FileFormat.prepareRead #13698

@@ -120,9 +120,12 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
   override def toString: String = "LibSVM"
 
   private def verifySchema(dataSchema: StructType): Unit = {
-    if (dataSchema.size != 2 ||
-      (!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
-        || !dataSchema(1).dataType.sameType(new VectorUDT()))) {
+    if (
+      dataSchema.size != 2 ||
+        !dataSchema(0).dataType.sameType(DataTypes.DoubleType) ||
+        !dataSchema(1).dataType.sameType(new VectorUDT()) ||
+        !(dataSchema(1).metadata.getLong("numFeatures").toInt > 0)
+    ) {
       throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
     }
   }
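
The added clause also requires the second column to carry a positive `numFeatures` entry in its column metadata. For reference, a minimal sketch of that metadata round trip using the public `MetadataBuilder`/`Metadata` API (the column names, the stand-in `StringType` for the Spark-internal `VectorUDT`, and the value 780 are illustrative):

```scala
import org.apache.spark.sql.types._

// Attach a "numFeatures" entry to a column's metadata, as the new inferSchema does.
val featuresMetadata = new MetadataBuilder()
  .putLong("numFeatures", 780L) // illustrative value
  .build()

val schema = StructType(
  StructField("label", DoubleType, nullable = false) ::
  StructField("features", StringType, nullable = false, featuresMetadata) :: Nil)

// Read the entry back, the same way verifySchema and buildReader do.
val numFeatures = schema("features").metadata.getLong("numFeatures").toInt
assert(numFeatures > 0)
```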
@@ -131,17 +134,8 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType] = {
-    Some(
-      StructType(
-        StructField("label", DoubleType, nullable = false) ::
-        StructField("features", new VectorUDT(), nullable = false) :: Nil))
-  }
-
-  override def prepareRead(
-      sparkSession: SparkSession,
-      options: Map[String, String],
-      files: Seq[FileStatus]): Map[String, String] = {
-    val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse {
+    val numFeatures: Int = options.get("numFeatures").map(_.toInt).filter(_ > 0).getOrElse {
       // Infers number of features if the user doesn't specify (a valid) one.
       val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
       val path = if (dataFiles.length == 1) {
         dataFiles.head.getPath.toUri.toString
@@ -156,7 +150,14 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
       MLUtils.computeNumFeatures(parsed)
     }
 
-    new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString))
+    val featuresMetadata = new MetadataBuilder()
+      .putLong("numFeatures", numFeatures)
+      .build()
+
+    Some(
+      StructType(
+        StructField("label", DoubleType, nullable = false) ::
+        StructField("features", new VectorUDT(), nullable = false, featuresMetadata) :: Nil))
   }
 
   override def prepareWrite(
@@ -185,7 +186,7 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
     verifySchema(dataSchema)
-    val numFeatures = options("numFeatures").toInt
+    val numFeatures = dataSchema("features").metadata.getLong("numFeatures").toInt
     assert(numFeatures > 0)
 
     val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
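
From the user's side the behaviour stays the same: `numFeatures` can still be passed as a reader option and is inferred from the data when absent; it now just reaches `buildReader` through the features column's metadata instead of an enriched option map. A small usage sketch, given a `SparkSession` named `spark` (the path is only an example, and reading the inferred value back from `df.schema` assumes the metadata survives into the resulting DataFrame's schema):

```scala
// Explicit numFeatures avoids the inference pass over the input files.
val dfExplicit = spark.read.format("libsvm")
  .option("numFeatures", "780")
  .load("data/mllib/sample_libsvm_data.txt")

// Without the option, inferSchema scans the data, computes numFeatures,
// and records it in the features column's metadata.
val dfInferred = spark.read.format("libsvm")
  .load("data/mllib/sample_libsvm_data.txt")

val inferredNumFeatures =
  dfInferred.schema("features").metadata.getLong("numFeatures").toInt
```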
@@ -386,17 +386,14 @@ case class DataSource(
             "It must be specified manually")
         }
 
-        val enrichedOptions =
-          format.prepareRead(sparkSession, caseInsensitiveOptions, fileCatalog.allFiles())
-
         HadoopFsRelation(
           sparkSession,
           fileCatalog,
           partitionSchema = fileCatalog.partitionSpec().partitionColumns,
           dataSchema = dataSchema.asNullable,
           bucketSpec = bucketSpec,
           format,
-          enrichedOptions)
+          caseInsensitiveOptions)
 
       case _ =>
         throw new AnalysisException(
@@ -26,7 +26,6 @@ import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec}
 import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
-import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
@@ -186,15 +185,6 @@ trait FileFormat {
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType]
 
-  /**
-   * Prepares a read job and returns a potentially updated data source option [[Map]]. This method
-   * can be useful for collecting necessary global information for scanning input data.
-   */
-  def prepareRead(
-      sparkSession: SparkSession,
-      options: Map[String, String],
-      files: Seq[FileStatus]): Map[String, String] = options
-
   /**
    * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
    * be put here. For example, user defined output committer can be configured here
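
For external `FileFormat` implementations that overrode the removed hook, the LibSVM change above suggests the replacement pattern: do the per-read preparation inside `inferSchema` and carry the result in the returned schema's column metadata, where `buildReader` can read it back. A hedged sketch of that pattern (`CustomTextFormat`, the `rowWidth` option/metadata key, and `scanFilesForWidth` are made up for illustration; the remaining `FileFormat` members are elided):

```scala
import org.apache.hadoop.fs.FileStatus
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

class CustomTextFormat extends FileFormat {

  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = {
    // Work that previously lived in prepareRead: derive a global value from the input files.
    val rowWidth: Long = options.get("rowWidth").map(_.toLong)
      .getOrElse(scanFilesForWidth(files))

    val metadata = new MetadataBuilder().putLong("rowWidth", rowWidth).build()
    Some(StructType(StructField("value", StringType, nullable = true, metadata) :: Nil))
  }

  // buildReader would then recover the value from the schema instead of the options map:
  //   val rowWidth = dataSchema("value").metadata.getLong("rowWidth")

  // Made-up helper standing in for whatever global scan the format needs.
  private def scanFilesForWidth(files: Seq[FileStatus]): Long = ???

  // prepareWrite and the other FileFormat members are elided in this sketch.
}
```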
@@ -454,7 +444,6 @@ private[sql] object HadoopFsRelation extends Logging {
     logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}")
 
     val sparkContext = sparkSession.sparkContext
-    val sqlConf = sparkSession.sessionState.conf
     val serializableConfiguration = new SerializableConfiguration(hadoopConf)
     val serializedPaths = paths.map(_.toString)
 