[SPARK-24502][SQL] flaky test: UnsafeRowSerializerSuite #21518

Closed · wants to merge 2 commits
@@ -30,11 +30,15 @@ trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self
override def beforeAll() {
super.beforeAll()
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
SparkSession.clearActiveSession()
SparkSession.clearDefaultSession()
}

override def afterEach() {
try {
resetSparkContext()
SparkSession.clearActiveSession()
SparkSession.clearDefaultSession()
} finally {
super.afterEach()
}
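
For context, a minimal self-contained sketch of the pattern this hunk introduces: clear any globally registered SparkSession before a suite starts and after every test, so that a later SparkSession.builder().getOrCreate() cannot silently reuse a session left behind by another suite. The trait name SessionIsolation and the standalone form are illustrative assumptions; the actual change is to the LocalSparkSession trait shown above.

import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Suite}

// Illustrative trait (not part of this PR): isolates each test from any
// SparkSession registered globally by earlier suites or earlier tests.
trait SessionIsolation extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite =>
  @transient var spark: SparkSession = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Drop whatever a previous suite may have left registered.
    SparkSession.clearActiveSession()
    SparkSession.clearDefaultSession()
  }

  override def afterEach(): Unit = {
    try {
      if (spark != null) {
        spark.stop()
        spark = null
      }
      // Also clear the references kept in SparkSession's companion object,
      // so the next test starts from a clean slate.
      SparkSession.clearActiveSession()
      SparkSession.clearDefaultSession()
    } finally {
      super.afterEach()
    }
  }
}
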
@@ -21,15 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File}
import java.util.Properties

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types._
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.ExternalSorter

/**
@@ -43,7 +41,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea
}
}

class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession {

private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
val converter = unsafeRowConverter(schema)
@@ -58,7 +56,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
}

test("toUnsafeRow() test helper method") {
// This currently doesnt work because the generic getter throws an exception.
// This currently doesn't work because the generic getter throws an exception.
val row = Row("Hello", 123)
val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
assert(row.getString(0) === unsafeRow.getUTF8String(0).toString)
@@ -97,59 +95,43 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
}

test("SPARK-10466: external sorter spilling with unsafe row serializer") {
var sc: SparkContext = null
var outputFile: File = null
val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten
Utils.tryWithSafeFinally {
val conf = new SparkConf()
.set("spark.shuffle.spill.initialMemoryThreshold", "1")
.set("spark.shuffle.sort.bypassMergeThreshold", "0")
.set("spark.testing.memory", "80000")

sc = new SparkContext("local", "test", conf)
outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
// prepare data
val converter = unsafeRowConverter(Array(IntegerType))
val data = (1 to 10000).iterator.map { i =>
(i, converter(Row(i)))
}
val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)

val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
taskContext,
partitioner = Some(new HashPartitioner(10)),
serializer = new UnsafeRowSerializer(numFields = 1))

// Ensure we spilled something and have to merge them later
assert(sorter.numSpills === 0)
sorter.insertAll(data)
assert(sorter.numSpills > 0)
val conf = new SparkConf()
.set("spark.shuffle.spill.initialMemoryThreshold", "1")
.set("spark.shuffle.sort.bypassMergeThreshold", "0")
.set("spark.testing.memory", "80000")
spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate()
val outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
outputFile.deleteOnExit()
// prepare data
val converter = unsafeRowConverter(Array(IntegerType))
val data = (1 to 10000).iterator.map { i =>
(i, converter(Row(i)))
}
val taskMemoryManager = new TaskMemoryManager(spark.sparkContext.env.memoryManager, 0)
val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)

// Merging spilled files should not throw assertion error
sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile)
} {
// Clean up
if (sc != null) {
sc.stop()
}
val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
taskContext,
partitioner = Some(new HashPartitioner(10)),
serializer = new UnsafeRowSerializer(numFields = 1))

// restore the spark env
SparkEnv.set(oldEnv)
// Ensure we spilled something and have to merge them later
assert(sorter.numSpills === 0)
sorter.insertAll(data)
assert(sorter.numSpills > 0)

if (outputFile != null) {
outputFile.delete()
}
}
// Merging spilled files should not throw assertion error
sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile)
}

test("SPARK-10403: unsafe row serializer with SortShuffleManager") {
val conf = new SparkConf().set("spark.shuffle.manager", "sort")
sc = new SparkContext("local", "test", conf)
spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate()
val row = Row("Hello", 123)
val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)))
.asInstanceOf[RDD[Product2[Int, InternalRow]]]
val rowsRDD = spark.sparkContext.parallelize(
Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow))
).asInstanceOf[RDD[Product2[Int, InternalRow]]]
val dependency =
new ShuffleDependency[Int, InternalRow, InternalRow](
rowsRDD,
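
As a side note on why the session clearing matters for the rewritten tests above: SparkSession.builder().getOrCreate() returns an already registered session when one exists, and SparkContext-level settings passed to the builder (such as spark.testing.memory here) do not take effect on that already running context. Below is a minimal standalone sketch of that behavior; the GetOrCreateReuse object name is made up for illustration and is not part of the PR.

import org.apache.spark.sql.SparkSession

// Illustrative only: shows that getOrCreate() reuses an existing session,
// which is why the suite clears the active/default session between tests.
object GetOrCreateReuse {
  def main(args: Array[String]): Unit = {
    val first = SparkSession.builder().master("local").appName("first").getOrCreate()

    // This builder asks for a different conf, but because `first` is still
    // registered, getOrCreate() hands back the same instance; SparkContext-level
    // settings like spark.testing.memory are not applied to the running context.
    val second = SparkSession.builder()
      .master("local")
      .appName("second")
      .config("spark.testing.memory", "80000")
      .getOrCreate()

    println(second eq first)  // prints: true

    first.stop()
    SparkSession.clearActiveSession()
    SparkSession.clearDefaultSession()
  }
}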