[SPARK-19085][SQL] cleanup OutputWriterFactory and OutputWriter #16479

Closed
wants to merge 2 commits
@@ -45,9 +45,12 @@ private[libsvm] class LibSVMOutputWriter(

private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))

override def write(row: Row): Unit = {
val label = row.get(0)
val vector = row.get(1).asInstanceOf[Vector]
// This `asInstanceOf` is safe because it's guaranteed by `LibSVMFileFormat.verifySchema`
Member:

LibSVMFileFormat.verifySchema is only called in buildReader, but this is the write path, right?

Contributor Author:

OK, I added the verification.

private val udt = dataSchema(1).dataType.asInstanceOf[VectorUDT]

override def write(row: InternalRow): Unit = {
val label = row.getDouble(0)
val vector = udt.deserialize(row.getStruct(1, udt.sqlType.length))
writer.write(label.toString)
vector.foreachActive { case (i, v) =>
writer.write(s" ${i + 1}:$v")
@@ -115,6 +118,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
verifySchema(dataSchema)
new OutputWriterFactory {
override def newInstance(
path: String,
@@ -17,12 +17,12 @@

package org.apache.spark.ml.source.libsvm

import java.io.File
import java.io.{File, IOException}
import java.nio.charset.StandardCharsets

import com.google.common.io.Files

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SaveMode}
@@ -100,7 +100,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {

test("write libsvm data failed due to invalid schema") {
val df = spark.read.format("text").load(path)
intercept[SparkException] {
intercept[IOException] {
df.write.format("libsvm").save(path + "_2")
}
}
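For context, a rough sketch of the behavior this test exercises after the change (hypothetical paths, assuming an existing SparkSession named spark; nothing below is code from this PR): a DataFrame whose schema is not (label: Double, features: Vector) is now rejected up front, because the format verifies the schema when the write job is prepared (the verifySchema(dataSchema) call added above), and the failure surfaces as an IOException, matching the updated test.

// Hypothetical usage sketch; paths and session are assumed, not from this PR.
val textDF = spark.read.format("text").load("/tmp/some-text")   // schema: value: String
try {
  textDF.write.format("libsvm").save("/tmp/libsvm-out")
} catch {
  case e: java.io.IOException =>
    println(s"Rejected at write setup: ${e.getMessage}")
}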
@@ -466,7 +466,7 @@ case class DataSource(
// SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does
// not need to have the query as child, to avoid to analyze an optimized query,
// because InsertIntoHadoopFsRelationCommand will be optimized first.
val columns = partitionColumns.map { name =>
val partitionAttributes = partitionColumns.map { name =>
val plan = data.logicalPlan
plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse {
throw new AnalysisException(
@@ -485,7 +485,7 @@
InsertIntoHadoopFsRelationCommand(
outputPath = outputPath,
staticPartitions = Map.empty,
partitionColumns = columns,
partitionColumns = partitionAttributes,
bucketSpec = bucketSpec,
fileFormat = format,
options = options,
@@ -64,18 +64,18 @@ object FileFormatWriter extends Logging {
val outputWriterFactory: OutputWriterFactory,
val allColumns: Seq[Attribute],
val partitionColumns: Seq[Attribute],
val nonPartitionColumns: Seq[Attribute],
val dataColumns: Seq[Attribute],
Contributor Author:

rename nonPartitionColumns to dataColumns, to be consistent with other places in the codebase.

val bucketSpec: Option[BucketSpec],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long)
extends Serializable {

assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
s"""
|All columns: ${allColumns.mkString(", ")}
|Partition columns: ${partitionColumns.mkString(", ")}
|Non-partition columns: ${nonPartitionColumns.mkString(", ")}
|Data columns: ${dataColumns.mkString(", ")}
""".stripMargin)
}
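For reference, a minimal sketch (assumed derivation, not lines from this diff) of how the renamed dataColumns relate to partitionColumns: the data columns are simply the output columns that are not partition columns, which is exactly the invariant asserted in WriteJobDescription above.

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}

// Sketch: split a plan's output into partition columns and data columns.
// This mirrors the invariant checked in WriteJobDescription above.
def dataColumnsOf(allColumns: Seq[Attribute], partitionColumns: Seq[Attribute]): Seq[Attribute] = {
  val partitionSet = AttributeSet(partitionColumns)
  allColumns.filterNot(partitionSet.contains)
}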

@@ -120,7 +120,7 @@ object FileFormatWriter extends Logging {
outputWriterFactory = outputWriterFactory,
allColumns = queryExecution.logical.output,
partitionColumns = partitionColumns,
nonPartitionColumns = dataColumns,
dataColumns = dataColumns,
bucketSpec = bucketSpec,
path = outputSpec.outputPath,
customPartitionLocations = outputSpec.customPartitionLocations,
@@ -246,9 +246,8 @@

currentWriter = description.outputWriterFactory.newInstance(
path = tmpFilePath,
dataSchema = description.nonPartitionColumns.toStructType,
dataSchema = description.dataColumns.toStructType,
context = taskAttemptContext)
currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
}

override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -267,7 +266,7 @@
}

val internalRow = iter.next()
currentWriter.writeInternal(internalRow)
currentWriter.write(internalRow)
recordsInFile += 1
}
releaseResources()
@@ -364,9 +363,8 @@

currentWriter = description.outputWriterFactory.newInstance(
path = path,
dataSchema = description.nonPartitionColumns.toStructType,
dataSchema = description.dataColumns.toStructType,
context = taskAttemptContext)
currentWriter.initConverter(description.nonPartitionColumns.toStructType)
}

override def execute(iter: Iterator[InternalRow]): Set[String] = {
Expand All @@ -383,7 +381,7 @@ object FileFormatWriter extends Logging {

// Returns the data columns to be written given an input row
val getOutputRow = UnsafeProjection.create(
description.nonPartitionColumns, description.allColumns)
description.dataColumns, description.allColumns)

// Returns the partition path given a partition key.
val getPartitionStringFunc = UnsafeProjection.create(
@@ -392,7 +390,7 @@
// Sorts the data before write, so that we only need one writer at the same time.
val sorter = new UnsafeKVExternalSorter(
sortingKeySchema,
StructType.fromAttributes(description.nonPartitionColumns),
StructType.fromAttributes(description.dataColumns),
SparkEnv.get.blockManager,
SparkEnv.get.serializerManager,
TaskContext.get().taskMemoryManager().pageSizeBytes,
@@ -448,7 +446,7 @@
newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
}

currentWriter.writeInternal(sortedIterator.getValue)
currentWriter.write(sortedIterator.getValue)
recordsInFile += 1
}
releaseResources()
@@ -45,7 +45,7 @@ case class InsertIntoHadoopFsRelationCommand(
bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
options: Map[String, String],
@transient query: LogicalPlan,
query: LogicalPlan,
mode: SaveMode,
catalogTable: Option[CatalogTable],
fileIndex: Option[FileIndex])
@@ -47,19 +47,6 @@ abstract class OutputWriterFactory extends Serializable {
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter

/**
* Returns a new instance of [[OutputWriter]] that will write data to the given path.
* This method gets called by each task on executor to write InternalRows to
* format-specific files. Compared to the other `newInstance()`, this is a newer API that
* passes only the path that the writer must write to. The writer must write to the exact path
* and not modify it (do not add subdirectories, extensions, etc.). All other
* file-format-specific information needed to create the writer must be passed
* through the [[OutputWriterFactory]] implementation.
*/
def newWriter(path: String): OutputWriter = {
throw new UnsupportedOperationException("newInstance with just path not supported")
}
}


@@ -74,22 +61,11 @@ abstract class OutputWriter {
* Persists a single row. Invoked on the executor side. When writing to dynamically partitioned
* tables, dynamic partition columns are not included in rows to be written.
Member:

The usage of this function was removed in https://github.com/apache/spark/pull/15710/files.

I think it is safe to remove it.

*/
def write(row: Row): Unit
def write(row: InternalRow): Unit

/**
* Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before
* the task output is committed.
*/
def close(): Unit

private var converter: InternalRow => Row = _

protected[sql] def initConverter(dataSchema: StructType) = {
converter =
CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row]
}

protected[sql] def writeInternal(row: InternalRow): Unit = {
write(converter(row))
}
}
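To summarize the slimmed-down contract, here is a minimal sketch of an OutputWriter implementation after this cleanup (a hypothetical no-op writer, not code from this diff): subclasses implement write(InternalRow) and close() directly, with no writeInternal override and no base-class Row conversion.

import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.OutputWriter

// Hypothetical writer showing the minimal surface after this change.
class NoOpOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter {
  override def write(row: InternalRow): Unit = ()   // format-specific encoding would go here
  override def close(): Unit = ()
}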
@@ -221,9 +221,7 @@ private[csv] class CsvOutputWriter(
row.get(ordinal, dt).toString
}

override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")

override protected[sql] def writeInternal(row: InternalRow): Unit = {
override def write(row: InternalRow): Unit = {
csvWriter.writeRow(rowToString(row), printHeader)
printHeader = false
}
@@ -159,9 +159,7 @@ private[json] class JsonOutputWriter(
// create the Generator without separator inserted between 2 records
private[this] val gen = new JacksonGenerator(dataSchema, writer, options)

override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")

override protected[sql] def writeInternal(row: InternalRow): Unit = {
override def write(row: InternalRow): Unit = {
gen.write(row)
gen.writeLineEnding()
}
@@ -37,9 +37,7 @@ private[parquet] class ParquetOutputWriter(path: String, context: TaskAttemptContext)
}.getRecordWriter(context)
}

override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")

override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row)
override def write(row: InternalRow): Unit = recordWriter.write(null, row)

override def close(): Unit = recordWriter.close(context)
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ class TextOutputWriter(

private val writer = CodecStreams.createOutputStream(context, new Path(path))

override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")

override protected[sql] def writeInternal(row: InternalRow): Unit = {
override def write(row: InternalRow): Unit = {
if (!row.isNullAt(0)) {
val utf8string = row.getUTF8String(0)
utf8string.writeTo(writer)
@@ -239,10 +239,7 @@ private[orc] class OrcOutputWriter(
).asInstanceOf[RecordWriter[NullWritable, Writable]]
}

override def write(row: Row): Unit =
throw new UnsupportedOperationException("call writeInternal")

override protected[sql] def writeInternal(row: InternalRow): Unit = {
override def write(row: InternalRow): Unit = {
recordWriter.write(NullWritable.get(), serializer.serialize(row))
}

@@ -20,7 +20,8 @@ package org.apache.spark.sql.sources
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}

import org.apache.spark.TaskContext
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

@@ -42,14 +43,14 @@ class CommitFailureTestSource extends SimpleTextSource {
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new SimpleTextOutputWriter(path, context) {
new SimpleTextOutputWriter(path, dataSchema, context) {
var failed = false
TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
failed = true
SimpleTextRelation.callbackCalled = true
}

override def write(row: Row): Unit = {
override def write(row: InternalRow): Unit = {
if (SimpleTextRelation.failWriter) {
sys.error("Intentional task writer failure for testing purpose.")

@@ -50,7 +50,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new SimpleTextOutputWriter(path, context)
new SimpleTextOutputWriter(path, dataSchema, context)
}

override def getFileExtension(context: TaskAttemptContext): String = ""
@@ -117,13 +117,13 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
}
}

class SimpleTextOutputWriter(path: String, context: TaskAttemptContext)
class SimpleTextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext)
extends OutputWriter {

private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))

override def write(row: Row): Unit = {
val serialized = row.toSeq.map { v =>
override def write(row: InternalRow): Unit = {
val serialized = row.toSeq(dataSchema).map { v =>
if (v == null) "" else v.toString
}.mkString(",")

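A closing note on why SimpleTextOutputWriter now takes dataSchema: unlike Row, InternalRow is not self-describing, so turning its fields back into generic values requires the field types. A minimal sketch of that access pattern (hypothetical schema and row, not code from this PR):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical schema and row; string fields are stored as UTF8String internally.
val schema = StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))
val row = InternalRow(1, UTF8String.fromString("spark"))

val line = row.toSeq(schema).map(v => if (v == null) "" else v.toString).mkString(",")
// line == "1,spark"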