
Removes FileFormat.prepareRead
liancheng committed Jun 16, 2016
1 parent 7c6c692 commit eeb8d52
Showing 3 changed files with 18 additions and 31 deletions.
@@ -120,9 +120,12 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
   override def toString: String = "LibSVM"
 
   private def verifySchema(dataSchema: StructType): Unit = {
-    if (dataSchema.size != 2 ||
-      (!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
-        || !dataSchema(1).dataType.sameType(new VectorUDT()))) {
+    if (
+      dataSchema.size != 2 ||
+        !dataSchema(0).dataType.sameType(DataTypes.DoubleType) ||
+        !dataSchema(1).dataType.sameType(new VectorUDT()) ||
+        !(dataSchema(1).metadata.getLong("numFeatures").toInt > 0)
+    ) {
       throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
     }
   }
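
The new guard relies on a "numFeatures" entry carried in the metadata of the features column. A minimal standalone sketch of that metadata round-trip with the public org.apache.spark.sql.types API (not part of this diff; ArrayType(DoubleType) stands in for the VectorUDT used in the actual source, and 780 is an arbitrary example value):

    import org.apache.spark.sql.types._

    // Build the field metadata the same way the new inferSchema does.
    val featuresMetadata = new MetadataBuilder()
      .putLong("numFeatures", 780L)
      .build()

    val schema = StructType(
      StructField("label", DoubleType, nullable = false) ::
      StructField("features", ArrayType(DoubleType), nullable = false, featuresMetadata) :: Nil)

    // Read it back, as verifySchema and buildReader now do.
    val numFeatures = schema("features").metadata.getLong("numFeatures").toInt
    assert(numFeatures > 0)
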
@@ -131,17 +134,8 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType] = {
-    Some(
-      StructType(
-        StructField("label", DoubleType, nullable = false) ::
-        StructField("features", new VectorUDT(), nullable = false) :: Nil))
-  }
-
-  override def prepareRead(
-      sparkSession: SparkSession,
-      options: Map[String, String],
-      files: Seq[FileStatus]): Map[String, String] = {
-    val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse {
+    val numFeatures: Int = options.get("numFeatures").map(_.toInt).filter(_ > 0).getOrElse {
+      // Infers number of features if the user doesn't specify (a valid) one.
       val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
       val path = if (dataFiles.length == 1) {
         dataFiles.head.getPath.toUri.toString
@@ -156,7 +150,14 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
       MLUtils.computeNumFeatures(parsed)
     }
 
-    new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString))
+    val featuresMetadata = new MetadataBuilder()
+      .putLong("numFeatures", numFeatures)
+      .build()
+
+    Some(
+      StructType(
+        StructField("label", DoubleType, nullable = false) ::
+        StructField("features", new VectorUDT(), nullable = false, featuresMetadata) :: Nil))
   }
 
   override def prepareWrite(
@@ -185,7 +186,7 @@ class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
     verifySchema(dataSchema)
-    val numFeatures = options("numFeatures").toInt
+    val numFeatures = dataSchema("features").metadata.getLong("numFeatures").toInt
     assert(numFeatures > 0)
 
     val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
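
End to end, reading libsvm data still goes through the usual DataFrameReader path; the numFeatures option stays optional, and after this change the resolved value can be read back from the column metadata rather than from the options map. A rough usage sketch, assuming a SparkSession named spark and a placeholder input path:

    // "numFeatures" may be omitted, in which case it is inferred from the data.
    val df = spark.read
      .format("libsvm")
      .option("numFeatures", "780")
      .load("data/sample_libsvm_data.txt")

    // The resolved feature count now lives in the metadata of the features column.
    val numFeatures = df.schema("features").metadata.getLong("numFeatures")
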
@@ -386,17 +386,14 @@ case class DataSource(
             "It must be specified manually")
         }
 
-        val enrichedOptions =
-          format.prepareRead(sparkSession, caseInsensitiveOptions, fileCatalog.allFiles())
-
         HadoopFsRelation(
           sparkSession,
           fileCatalog,
           partitionSchema = fileCatalog.partitionSpec().partitionColumns,
           dataSchema = dataSchema.asNullable,
           bucketSpec = bucketSpec,
           format,
-          enrichedOptions)
+          caseInsensitiveOptions)
 
       case _ =>
         throw new AnalysisException(
@@ -26,7 +26,6 @@ import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompres
 import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
-import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
@@ -186,15 +185,6 @@ trait FileFormat {
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType]
 
-  /**
-   * Prepares a read job and returns a potentially updated data source option [[Map]]. This method
-   * can be useful for collecting necessary global information for scanning input data.
-   */
-  def prepareRead(
-      sparkSession: SparkSession,
-      options: Map[String, String],
-      files: Seq[FileStatus]): Map[String, String] = options
-
   /**
    * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
    * be put here. For example, user defined output committer can be configured here
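
With prepareRead gone, a file format that needs globally computed, read-time information can fold it into the schema it returns from inferSchema, as the LibSVM changes above do. A hedged sketch of that pattern as a free-standing method (the totalBytes key and the value column are illustrative placeholders, not Spark API):

    import org.apache.hadoop.fs.FileStatus
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types._

    // Compute a global property of the input once, then carry it in column metadata
    // so that buildReader can recover it from dataSchema instead of the options map.
    def inferSchemaWithMetadata(
        sparkSession: SparkSession,
        options: Map[String, String],
        files: Seq[FileStatus]): Option[StructType] = {
      val totalBytes = files.map(_.getLen).sum

      val valueMetadata = new MetadataBuilder()
        .putLong("totalBytes", totalBytes)
        .build()

      Some(StructType(
        StructField("value", StringType, nullable = true, valueMetadata) :: Nil))
    }
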
@@ -454,7 +444,6 @@ private[sql] object HadoopFsRelation extends Logging {
     logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}")
 
     val sparkContext = sparkSession.sparkContext
-    val sqlConf = sparkSession.sessionState.conf
     val serializableConfiguration = new SerializableConfiguration(hadoopConf)
     val serializedPaths = paths.map(_.toString)

