initial commit
ericl committed Mar 29, 2017
1 parent 746a558 commit a541fdd
Showing 3 changed files with 53 additions and 7 deletions.
FileCommitProtocol.scala
@@ -121,6 +121,13 @@ abstract class FileCommitProtocol {
def deleteWithJob(fs: FileSystem, path: Path, recursive: Boolean): Boolean = {
fs.delete(path, recursive)
}

/**
* Called on the driver after a task commits. This can be used to access task commit messages
* before the job has finished. These same task commit messages will be passed to commitJob()
* if the entire job succeeds.
*/
def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {}
}


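The hunk above adds the `onTaskCommit` hook to `FileCommitProtocol`: the driver now gets a callback for each committed task instead of only seeing the full set of commit messages in `commitJob()`. Below is a minimal sketch of a protocol that uses the hook; the class name and progress-counting behavior are illustrative only, and the `(jobId, path)` constructor is assumed from the `HadoopMapReduceCommitProtocol` subclass shown in the test file further down.

```scala
import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol

// Hypothetical example (not part of this commit): report write progress on the
// driver as each task commits, instead of waiting for commitJob() at the end.
class ProgressReportingCommitProtocol(jobId: String, path: String)
  extends HadoopMapReduceCommitProtocol(jobId, path) {

  // onTaskCommit() runs on the driver, once per successfully committed task,
  // so a plain driver-side counter is enough here.
  private val committedTasks = new AtomicInteger(0)

  override def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {
    val n = committedTasks.incrementAndGet()
    println(s"Job $jobId: $n task(s) committed so far")
  }
}
```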
FileFormatWriter.scala
@@ -80,6 +80,9 @@ object FileFormatWriter extends Logging {
""".stripMargin)
}

/** The result of a successful write task. */
private case class WriteTaskResult(commitMsg: TaskCommitMessage, updatedPartitions: Set[String])

/**
* Basic work flow of this command is:
* 1. Driver side setup, including output committer initialization and data source specific
@@ -172,8 +175,9 @@
global = false,
child = queryExecution.executedPlan).execute()
}

val ret = sparkSession.sparkContext.runJob(rdd,
val ret = new Array[WriteTaskResult](rdd.partitions.length)
sparkSession.sparkContext.runJob(
rdd,
(taskContext: TaskContext, iter: Iterator[InternalRow]) => {
executeTask(
description = description,
@@ -182,10 +186,16 @@
sparkAttemptNumber = taskContext.attemptNumber(),
committer,
iterator = iter)
},
0 until rdd.partitions.length,
(index, res: WriteTaskResult) => {
committer.onTaskCommit(res.commitMsg)
ret(index) = res
})

val commitMsgs = ret.map(_._1)
val updatedPartitions = ret.flatMap(_._2).distinct.map(PartitioningUtils.parsePathFragment)
val commitMsgs = ret.map(_.commitMsg)
val updatedPartitions = ret.flatMap(_.updatedPartitions)
.distinct.map(PartitioningUtils.parsePathFragment)

committer.commitJob(job, commitMsgs)
logInfo(s"Job ${job.getJobID} committed.")
@@ -205,7 +215,7 @@
sparkPartitionId: Int,
sparkAttemptNumber: Int,
committer: FileCommitProtocol,
iterator: Iterator[InternalRow]): (TaskCommitMessage, Set[String]) = {
iterator: Iterator[InternalRow]): WriteTaskResult = {

val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId)
val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
@@ -238,7 +248,7 @@
// Execute the task to write rows out and commit the task.
val outputPartitions = writeTask.execute(iterator)
writeTask.releaseResources()
(committer.commitTask(taskAttemptContext), outputPartitions)
WriteTaskResult(committer.commitTask(taskAttemptContext), outputPartitions)
})(catchBlock = {
// If there is an error, release resource and then abort the task
try {
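The key mechanical change in `FileFormatWriter` above is that the write job no longer relies on the value returned by `runJob`: it passes an explicit partition range and result-handler function, so the driver can call `committer.onTaskCommit(...)` the moment each task's `WriteTaskResult` arrives, and only then record it into `ret`. The standalone sketch below shows the same `SparkContext.runJob` overload outside the write path; the object and variable names are illustrative.

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession

// Illustrative sketch of SparkContext.runJob with a per-task result handler:
// the handler runs on the driver as each task finishes, which is what lets
// FileFormatWriter invoke onTaskCommit() before the whole job completes.
object RunJobHandlerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("runJob-sketch").getOrCreate()
    val sc = spark.sparkContext

    val rdd = sc.parallelize(1 to 100, numSlices = 4)
    val perPartitionSums = new Array[Int](rdd.partitions.length)

    sc.runJob(
      rdd,
      (ctx: TaskContext, iter: Iterator[Int]) => iter.sum,  // runs on the executors
      0 until rdd.partitions.length,
      (index: Int, sum: Int) => {                           // runs on the driver, per finished task
        println(s"partition $index finished with sum $sum")
        perPartitionSums(index) = sum
      })

    println(s"total = ${perPartitionSums.sum}")
    spark.stop()
  }
}
```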
DataFrameReaderWriterSuite.scala
@@ -18,9 +18,12 @@
package org.apache.spark.sql.test

import java.io.File
import java.util.concurrent.ConcurrentLinkedQueue

import org.scalatest.BeforeAndAfter

import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.sources._
@@ -41,7 +44,6 @@ object LastOptions {
}
}


/** Dummy provider. */
class DefaultSource
extends RelationProvider
@@ -107,6 +109,20 @@ class DefaultSourceWithoutUserSpecifiedSchema
}
}

object MessageCapturingCommitProtocol {
val commitMessages = new ConcurrentLinkedQueue[TaskCommitMessage]()
}

class MessageCapturingCommitProtocol(jobId: String, path: String)
extends HadoopMapReduceCommitProtocol(jobId, path) {

// captures commit messages for testing
override def onTaskCommit(msg: TaskCommitMessage): Unit = {
MessageCapturingCommitProtocol.commitMessages.offer(msg)
}
}


class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
import testImplicits._

@@ -291,6 +307,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
Option(dir).map(spark.read.format("org.apache.spark.sql.test").load)
}

test("write path implements onTaskCommit API correctly") {
withSQLConf(
"spark.sql.sources.commitProtocolClass" ->
classOf[MessageCapturingCommitProtocol].getCanonicalName) {
withTempDir { dir =>
val path = dir.getCanonicalPath
MessageCapturingCommitProtocol.commitMessages.clear()
spark.range(10).repartition(10).write.mode("overwrite").parquet(path)
assert(MessageCapturingCommitProtocol.commitMessages.size() == 10)
}
}
}

test("read a data source that does not extend SchemaRelationProvider") {
val dfReader = spark.read
.option("from", "1")
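The new test registers `MessageCapturingCommitProtocol` through the `spark.sql.sources.commitProtocolClass` conf and checks that one commit message is captured per write task. A hedged sketch of doing the same from an application follows; the object name and output path are illustrative, and any `FileCommitProtocol` subclass with a `(jobId: String, path: String)` constructor could be supplied in place of the test helper.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative usage sketch (not part of the diff): plug a custom commit protocol
// into the SQL write path, the same way the test does via withSQLConf.
object CustomCommitProtocolApp {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("onTaskCommit-sketch")
      .config("spark.sql.sources.commitProtocolClass",
        "org.apache.spark.sql.test.MessageCapturingCommitProtocol")
      .getOrCreate()

    // 10 partitions => 10 write tasks => 10 onTaskCommit() calls on the driver.
    spark.range(10).repartition(10).write.mode("overwrite").parquet("/tmp/onTaskCommit-sketch")

    spark.stop()
  }
}
```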
