Bunch of fixes for longer running jobs
1. Increase the timeout for the socket connection so it can wait for long-running jobs
2. Add some profiling information in worker.R
3. Put temp file writes before stdin writes in RRDD.scala
shivaram committed Feb 1, 2015
1 parent 227ee42 commit 179aa75
Showing 3 changed files with 86 additions and 58 deletions.
2 changes: 1 addition & 1 deletion pkg/R/sparkRClient.R
@@ -2,7 +2,7 @@

# Creates a SparkR client connection object
# if one doesn't already exist
-connectBackend <- function(hostname, port, timeout = 60) {
+connectBackend <- function(hostname, port, timeout = 6000) {
if (exists(".sparkRcon", envir = .sparkREnv)) {
cat("SparkRBackend client connection already exists\n")
return(get(".sparkRcon", envir = .sparkREnv))
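
For context on item 1 of the commit message: the diff only shows the changed signature, but a timeout like this is normally passed straight through to R's socketConnection, whose timeout argument is the number of seconds a blocking read or write may wait. A minimal sketch of that pattern follows; connectWithTimeout and its body are illustrative, not the actual connectBackend implementation, which is elided above.

# Illustrative only: how a timeout is typically threaded through to
# socketConnection. The real connectBackend body is not shown in this diff.
connectWithTimeout <- function(hostname, port, timeout = 6000) {
  socketConnection(host = hostname, port = port, server = FALSE,
                   blocking = TRUE, open = "wb", timeout = timeout)
}

With the old value of 60 seconds, a backend call that ran longer than a minute could presumably time out on the R side; raising it to 6000 seconds gives long-running jobs room to finish, which matches the intent stated in the commit message.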
19 changes: 19 additions & 0 deletions pkg/inst/worker/worker.R
@@ -1,5 +1,7 @@
# Worker class

begin <- proc.time()[3]

# NOTE: We use "stdin" to get the process stdin instead of the command line
inputConStdin <- file("stdin", open = "rb")

@@ -65,6 +67,8 @@ numPartitions <- SparkR:::readInt(inputCon)

isEmpty <- SparkR:::readInt(inputCon)

metadataEnd <- proc.time()[3]

if (isEmpty != 0) {

if (numPartitions == -1) {
@@ -74,12 +78,15 @@ if (isEmpty != 0) {
} else {
data <- readLines(inputCon)
}
dataReadEnd <- proc.time()[3]
output <- do.call(execFunctionName, list(splitIndex, data))
computeEnd <- proc.time()[3]
if (isOutputSerialized) {
SparkR:::writeRawSerialize(outputCon, output)
} else {
SparkR:::writeStrings(outputCon, output)
}
writeEnd <- proc.time()[3]
} else {
if (isInputSerialized) {
# Now read as many characters as described in funcLen
@@ -88,6 +95,7 @@ if (isEmpty != 0) {
data <- readLines(inputCon)
}

dataReadEnd <- proc.time()[3]
res <- new.env()

# Step 1: hash the data to an environment
@@ -105,6 +113,8 @@ if (isEmpty != 0) {
}
invisible(lapply(data, hashTupleToEnvir))

computeEnd <- proc.time()[3]

# Step 2: write out all of the environment as key-value pairs.
for (name in ls(res)) {
SparkR:::writeInt(outputCon, 2L)
@@ -113,6 +123,7 @@ if (isEmpty != 0) {
length(res[[name]]$data) <- res[[name]]$counter
SparkR:::writeRawSerialize(outputCon, res[[name]]$data)
}
writeEnd <- proc.time()[3]
}
}

@@ -128,5 +139,13 @@ unlink(inFileName)
# Restore stdout
sink()

end <- proc.time()[3]

cat("stats: total ", (end-begin), "\n", file=stderr())
cat("stats: metadata ", (metadataEnd-begin), "\n", file=stderr())
cat("stats: input read ", (dataReadEnd-metadataEnd), "\n", file=stderr())
cat("stats: compute ", (computeEnd-dataReadEnd), "\n", file=stderr())
cat("stats: output write ", (writeEnd-computeEnd), "\n", file=stderr())

# Finally print the name of the output file
cat(outputFileName, "\n")
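
The profiling added above is based on proc.time()[3], the elapsed (wall-clock) time in seconds since the R process started, so differences between successive readings give the time spent in each phase (metadata, input read, compute, output write). A standalone sketch of the same pattern, with Sys.sleep standing in for real work:

# Standalone illustration of the timing pattern used in worker.R above.
begin <- proc.time()[3]             # elapsed seconds since process start
Sys.sleep(0.5)                      # stand-in for reading the input
dataReadEnd <- proc.time()[3]
Sys.sleep(0.5)                      # stand-in for the compute step
computeEnd <- proc.time()[3]
cat("stats: input read ", (dataReadEnd - begin), "\n", file = stderr())
cat("stats: compute ", (computeEnd - dataReadEnd), "\n", file = stderr())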
123 changes: 66 additions & 57 deletions pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/RRDD.scala
@@ -117,68 +117,77 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for R") {
override def run() {
SparkEnv.set(env)
val streamStd = new BufferedOutputStream(proc.getOutputStream, bufferSize)
val printOutStd = new PrintStream(streamStd)
printOutStd.println(tempFileName)
printOutStd.println(rLibDir)
printOutStd.println(tempFileIn.getAbsolutePath())
printOutStd.flush()

streamStd.close()

val stream = new BufferedOutputStream(new FileOutputStream(tempFileIn), bufferSize)
val printOut = new PrintStream(stream)
val dataOut = new DataOutputStream(stream)

dataOut.writeInt(splitIndex)

dataOut.writeInt(func.length)
dataOut.write(func, 0, func.length)

// R worker process input serialization flag
dataOut.writeInt(if (parentSerialized) 1 else 0)
// R worker process output serialization flag
dataOut.writeInt(if (dataSerialized) 1 else 0)

dataOut.writeInt(packageNames.length)
dataOut.write(packageNames, 0, packageNames.length)

dataOut.writeInt(functionDependencies.length)
dataOut.write(functionDependencies, 0, functionDependencies.length)

dataOut.writeInt(broadcastVars.length)
broadcastVars.foreach { broadcast =>
// TODO(shivaram): Read a Long in R to avoid this cast
dataOut.writeInt(broadcast.id.toInt)
// TODO: Pass a byte array from R to avoid this cast ?
val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
dataOut.writeInt(broadcastByteArr.length)
dataOut.write(broadcastByteArr, 0, broadcastByteArr.length)
}

dataOut.writeInt(numPartitions)
try {
SparkEnv.set(env)
val stream = new BufferedOutputStream(new FileOutputStream(tempFileIn), bufferSize)
val printOut = new PrintStream(stream)
val dataOut = new DataOutputStream(stream)

dataOut.writeInt(splitIndex)

dataOut.writeInt(func.length)
dataOut.write(func, 0, func.length)

// R worker process input serialization flag
dataOut.writeInt(if (parentSerialized) 1 else 0)
// R worker process output serialization flag
dataOut.writeInt(if (dataSerialized) 1 else 0)

dataOut.writeInt(packageNames.length)
dataOut.write(packageNames, 0, packageNames.length)

dataOut.writeInt(functionDependencies.length)
dataOut.write(functionDependencies, 0, functionDependencies.length)

dataOut.writeInt(broadcastVars.length)
broadcastVars.foreach { broadcast =>
// TODO(shivaram): Read a Long in R to avoid this cast
dataOut.writeInt(broadcast.id.toInt)
// TODO: Pass a byte array from R to avoid this cast ?
val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
dataOut.writeInt(broadcastByteArr.length)
dataOut.write(broadcastByteArr, 0, broadcastByteArr.length)
}

if (!iter.hasNext) {
dataOut.writeInt(0)
} else {
dataOut.writeInt(1)
}
dataOut.writeInt(numPartitions)

for (elem <- iter) {
if (parentSerialized) {
val elemArr = elem.asInstanceOf[Array[Byte]]
dataOut.writeInt(elemArr.length)
dataOut.write(elemArr, 0, elemArr.length)
if (!iter.hasNext) {
dataOut.writeInt(0)
} else {
printOut.println(elem)
dataOut.writeInt(1)
}

for (elem <- iter) {
if (parentSerialized) {
val elemArr = elem.asInstanceOf[Array[Byte]]
dataOut.writeInt(elemArr.length)
dataOut.write(elemArr, 0, elemArr.length)
} else {
printOut.println(elem)
}
}
}

printOut.flush()
dataOut.flush()
stream.flush()
stream.close()
printOut.flush()
dataOut.flush()
stream.flush()
stream.close()

// NOTE: We need to write out the temp file before writing out the
// file name to stdin. Otherwise the R process could read partial state
val streamStd = new BufferedOutputStream(proc.getOutputStream, bufferSize)
val printOutStd = new PrintStream(streamStd)
printOutStd.println(tempFileName)
printOutStd.println(rLibDir)
printOutStd.println(tempFileIn.getAbsolutePath())
printOutStd.flush()

streamStd.close()
} catch {
// TODO: We should propagate this error to the task thread
case e: Exception =>
System.err.println("R Writer thread got an exception " + e)
e.printStackTrace()
}
}
}.start()

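
Item 3 of the commit message is motivated by how the R worker consumes its input: worker.R blocks on stdin for the file names, then opens the input temp file right away, so the JVM has to finish and close that file before sending its path. A plausible sketch of the worker-side handshake is below; the variable names match worker.R, but the exact readLines calls are not shown in this diff and should be read as illustrative.

# Illustrative sketch of the R worker's handshake: file names arrive on stdin
# (in the same order as the println calls above), and the input file is opened
# immediately afterwards. If RRDD.scala sent the path before finishing the
# temp file, this open/read could see partial state.
inputConStdin <- file("stdin", open = "rb")
outputFileName <- readLines(inputConStdin, n = 1)
rLibDir <- readLines(inputConStdin, n = 1)
inFileName <- readLines(inputConStdin, n = 1)
inputCon <- file(inFileName, open = "rb")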
