Cache before partitioned-write, provide unpersist handle (#124)
EnricoMi authored Oct 26, 2022
1 parent 74c8075 commit 5786959
Showing 5 changed files with 262 additions and 59 deletions.
12 changes: 12 additions & 0 deletions PARTITIONING.md
@@ -90,6 +90,7 @@ If you need each file to further be sorted by additional columns, e.g. `ts`, the

ds.repartitionByRange($"property", $"id")
.sortWithinPartitions($"property", $"id", $"ts")
.cache // this is needed for Spark 3.0 to 3.3 with AQE enabled: SPARK-40588
.write
.partitionBy("property")
.csv("file.csv")
@@ -100,15 +101,26 @@ e.g. the date-representation of the `ts` column.
ds.withColumn("date", $"ts".cast(DateType))
.repartitionByRange($"date", $"id")
.sortWithinPartitions($"date", $"id", $"ts")
.cache // this is needed for Spark 3.0 to 3.3 with AQE enabled: SPARK-40588
.write
.partitionBy("date")
.csv("file.csv")

All of the above constructs can be replaced with a single meaningful operation:

ds.writePartitionedBy(Seq($"ts".cast(DateType).as("date")), Seq($"id"), Seq($"ts"))
.csv("file.csv")

For Spark 3.0 to 3.3 with AQE enabled (see [SPARK-40588](https://issues.apache.org/jira/browse/SPARK-40588)),
`writePartitionedBy` has to cache an internally created DataFrame. That DataFrame can be unpersisted once
writing has finished. Provide an `UnpersistHandle` for this purpose:

val unpersist = UnpersistHandle()

ds.writePartitionedBy(…, unpersistHandle = Some(unpersist))
.csv("file.csv")

unpersist()
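
Since this commit also adds `UnpersistHandle.withUnpersist`, the handle can instead be scoped so that unpersisting happens automatically once the block returns (a sketch, reusing `ds` from the examples above):

UnpersistHandle.withUnpersist() { handle =>
  ds.writePartitionedBy(Seq($"ts".cast(DateType).as("date")), Seq($"id"), Seq($"ts"),
      unpersistHandle = Some(handle))
    .csv("file.csv")
}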

<!--
# Other Approaches
30 changes: 25 additions & 5 deletions src/main/scala/uk/co/gresearch/spark/UnpersistHandle.scala
@@ -10,9 +10,10 @@ import org.apache.spark.sql.DataFrame
class UnpersistHandle {
var df: Option[DataFrame] = None

private[spark] def setDataFrame(dataframe: DataFrame): Unit = {
if (df.isDefined) throw new IllegalStateException("DataFrame has been set already. It cannot be reused once used with withRowNumbers.")
private[spark] def setDataFrame(dataframe: DataFrame): DataFrame = {
if (df.isDefined) throw new IllegalStateException("DataFrame has been set already, it cannot be reused.")
this.df = Some(dataframe)
dataframe
}

def apply(): Unit = {
@@ -24,13 +25,32 @@ class UnpersistHandle {
}
}

class NoopUnpersistHandle extends UnpersistHandle{
override def setDataFrame(dataframe: DataFrame): Unit = {}
case class SilentUnpersistHandle() extends UnpersistHandle {
override def apply(): Unit = {
this.df.foreach(_.unpersist())
}

override def apply(blocking: Boolean): Unit = {
this.df.foreach(_.unpersist(blocking))
}
}

case class NoopUnpersistHandle() extends UnpersistHandle{
override def setDataFrame(dataframe: DataFrame): DataFrame = dataframe
override def apply(): Unit = {}
override def apply(blocking: Boolean): Unit = {}
}

object UnpersistHandle {
val Noop: NoopUnpersistHandle = NoopUnpersistHandle()
def apply(): UnpersistHandle = new UnpersistHandle()
val Noop = new NoopUnpersistHandle()

def withUnpersist[T](blocking: Boolean = false)(func: UnpersistHandle => T): T = {
val handle = SilentUnpersistHandle()
try {
func(handle)
} finally {
handle(blocking)
}
}
}
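
A minimal sketch of how the three handle flavours above differ; `setDataFrame` is package-private and is called internally by `writePartitionedBy`, so calling code only invokes the handle to unpersist:

val strict = UnpersistHandle()        // apply() throws until a DataFrame has been set internally
val silent = SilentUnpersistHandle()  // apply() silently does nothing until a DataFrame has been set
val noop   = UnpersistHandle.Noop     // never tracks a DataFrame, apply() never unpersists

silent()               // no error: nothing has been cached yet
noop(blocking = true)  // always a no-op
strict()               // throws IllegalStateException("DataFrame has to be set first")
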
93 changes: 85 additions & 8 deletions src/main/scala/uk/co/gresearch/spark/package.scala
@@ -16,14 +16,19 @@

package uk.co.gresearch

import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.storage.StorageLevel
import uk.co.gresearch.spark.group.SortedGroupByDataset

package object spark {
import java.io.IOException
import scala.util.Properties

package object spark extends Logging {

/**
* Provides a prefix that makes any string distinct w.r.t. the given strings.
@@ -34,6 +39,46 @@ package object spark {
"_" * (existing.map(_.takeWhile(_ == '_').length).reduceOption(_ max _).getOrElse(0) + 1)
}

private[spark] lazy val getSparkVersion: Option[String] = {
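// Reads the Spark version from the spark-sql pom.properties on the classpath, e.g. (hypothetical
// values) Some("3.3.0") when spark-sql_2.12 version 3.3.0 is on the classpath, or None when the
// properties cannot be found or read.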
val scalaCompatVersionOpt = Properties.releaseVersion.map(_.split("\\.").take(2).mkString("."))
scalaCompatVersionOpt.flatMap { scalaCompatVersion =>
val propFilePath = s"META-INF/maven/org.apache.spark/spark-sql_$scalaCompatVersion/pom.properties"
Option(ClassLoader.getSystemClassLoader.getResourceAsStream(propFilePath)).flatMap { in =>
val props = try {
val props = new java.util.Properties()
props.load(in)
Some(props)
} catch {
case _: IOException => None
}

props.flatMap { props =>
val ver = Option(props.getProperty("version"))
val group = Option(props.getProperty("groupId"))
val artifact = Option(props.getProperty("artifactId"))

ver.filter(_ =>
group.exists(_.equals("org.apache.spark")) &&
artifact.exists(_.equals(s"spark-sql_$scalaCompatVersion"))
)
}
}
}
}

private[spark] def writePartitionedByRequiresCaching[T](ds: Dataset[T]): Boolean = {
ds.sparkSession.conf.get(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key,
SQLConf.ADAPTIVE_EXECUTION_ENABLED.defaultValue.getOrElse(true).toString
).equalsIgnoreCase("true") && getSparkVersion.exists(ver =>
ver.startsWith("3.0.") || ver.startsWith("3.1.") || ver.startsWith("3.2.") || ver.startsWith("3.3.")
)
}

private[spark] def info(msg: String): Unit = logInfo(msg)
private[spark] def warning(msg: String): Unit = logWarning(msg)

/**
* Encloses the given strings with backticks if needed. Multiple strings will be enclosed individually and
* concatenated with dots (`.`).
@@ -88,38 +133,69 @@
* file per `partitionBy` partition only. Rows within the partition files are also sorted,
* if partitionOrder is defined.
*
* Note: With Spark 3.0, 3.1, 3.2 and 3.3 and AQE enabled, an intermediate DataFrame is being
* cached in order to guarantee sorted output files. See https://issues.apache.org/jira/browse/SPARK-40588.
* That cached DataFrame can be unpersisted via an optional [[UnpersistHandle]] provided to this method.
*
* Calling:
* {{{
* df.writePartitionedBy(Seq("a"), Seq("b"), Seq("c"), Some(10), Seq($"a", concat($"b", $"c")))
* val unpersist = UnpersistHandle()
* val writer = df.writePartitionedBy(Seq("a"), Seq("b"), Seq("c"), Some(10), Seq($"a", concat($"b", $"c")), unpersist)
* writer.parquet("data.parquet")
* unpersist()
* }}}
*
* is equivalent to:
* {{{
* df.repartitionByRange(10, $"a", $"b")
* .sortWithinPartitions($"a", $"b", $"c")
* .select($"a", concat($"b", $"c"))
* .write
* .partitionBy("a")
* val cached =
* df.repartitionByRange(10, $"a", $"b")
* .sortWithinPartitions($"a", $"b", $"c")
* .cache
*
* val writer =
* cached
* .select($"a", concat($"b", $"c"))
* .write
* .partitionBy("a")
*
* writer.parquet("data.parquet")
*
* cached.unpersist
* }}}
*
* @param partitionColumns columns used for partitioning
* @param moreFileColumns columns where individual values are written to a single file
* @param moreFileOrder additional columns to sort partition files
* @param partitions optional number of partition files
* @param writtenProjection additional transformation to be applied before calling write
* @param unpersistHandle handle to unpersist internally created DataFrame after writing
* @return configured DataFrameWriter
*/
def writePartitionedBy(partitionColumns: Seq[Column],
moreFileColumns: Seq[Column] = Seq.empty,
moreFileOrder: Seq[Column] = Seq.empty,
partitions: Option[Int] = None,
writtenProjection: Option[Seq[Column]] = None): DataFrameWriter[Row] = {
writtenProjection: Option[Seq[Column]] = None,
unpersistHandle: Option[UnpersistHandle] = None): DataFrameWriter[Row] = {
if (partitionColumns.isEmpty)
throw new IllegalArgumentException(s"partition columns must not be empty")

if (partitionColumns.exists(!_.expr.isInstanceOf[NamedExpression]))
throw new IllegalArgumentException(s"partition columns must be named: ${partitionColumns.mkString(",")}")

val requiresCaching = writePartitionedByRequiresCaching(ds)
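// Warn if caching is required (Spark 3.0 to 3.3 with AQE enabled) but the caller gave no handle to
// unpersist the cached DataFrame; if a handle was given although caching is not required, hand it an
// empty DataFrame so that calling the handle later stays harmless.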
(requiresCaching, unpersistHandle.isDefined) match {
case (true, false) =>
warning("Partitioned-writing with AQE enabled and Spark 3.0 to 3.3 requires caching " +
"an intermediate DataFrame, which calling code has to unpersist once writing is done. " +
"Please provide an UnpersistHandle to DataFrame.writePartitionedBy, or UnpersistHandle.Noop.")
case (false, true) if !unpersistHandle.get.isInstanceOf[NoopUnpersistHandle] =>
info("UnpersistHandle provided to DataFrame.writePartitionedBy is not needed as " +
"partitioned-writing with AQE disabled or Spark 3.4 and above does not require caching an intermediate DataFrame.")
unpersistHandle.get.setDataFrame(ds.sparkSession.emptyDataFrame)
case _ =>
}

val partitionColumnsMap = partitionColumns.map(c => c.expr.asInstanceOf[NamedExpression].name -> c).toMap
val partitionColumnNames = partitionColumnsMap.keys.map(col).toSeq
val rangeColumns = partitionColumnNames ++ moreFileColumns
@@ -130,6 +206,7 @@ package object spark {
.when(partitions.isDefined).call(_.repartitionByRange(partitions.get, rangeColumns: _*))
.sortWithinPartitions(sortColumns: _*)
.when(writtenProjection.isDefined).call(_.select(writtenProjection.get: _*))
.when(requiresCaching && unpersistHandle.isDefined).call(unpersistHandle.get.setDataFrame(_))
.write
.partitionBy(partitionColumnsMap.keys.toSeq: _*)
}
128 changes: 90 additions & 38 deletions src/test/scala/uk/co/gresearch/spark/SparkSuite.scala
@@ -32,6 +32,94 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
val emptyDataset: Dataset[Value] = spark.emptyDataset[Value]
val emptyDataFrame: DataFrame = spark.createDataFrame(Seq.empty[Value])

test("get spark version") {
assert(getSparkVersion.isDefined)
}

Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY, NONE).foreach { level =>
Seq(
("UnpersistHandle", UnpersistHandle()),
("SilentUnpersistHandle", SilentUnpersistHandle())
).foreach { case (handleClass, unpersist) =>
test(s"$handleClass does unpersist set DataFrame with $level") {
val cacheManager = spark.sharedState.cacheManager
cacheManager.clearCache()
assert(cacheManager.isEmpty === true)

val df = spark.emptyDataFrame
assert(cacheManager.lookupCachedData(spark.emptyDataFrame).isDefined === false)

unpersist.setDataFrame(df)
assert(cacheManager.lookupCachedData(spark.emptyDataFrame).isDefined === false)

df.cache()
assert(cacheManager.lookupCachedData(spark.emptyDataFrame).isDefined === true)

unpersist(blocking = true)
assert(cacheManager.lookupCachedData(spark.emptyDataFrame).isDefined === false)

// calling this twice does not throw any errors
unpersist()
}
}
}

Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY, NONE).foreach { level =>
test(s"NoopUnpersistHandle does not unpersist set DataFrame with $level") {
val cacheManager = spark.sharedState.cacheManager
cacheManager.clearCache()
assert(cacheManager.isEmpty === true)

val df = spark.emptyDataFrame
assert(cacheManager.lookupCachedData(spark.emptyDataFrame).isDefined === false)

val unpersist = UnpersistHandle.Noop
unpersist.setDataFrame(df)
assert(cacheManager.lookupCachedData(spark.emptyDataFrame).isDefined === false)

df.cache()
assert(cacheManager.lookupCachedData(spark.emptyDataFrame).isDefined === true)

unpersist(blocking = true)
assert(cacheManager.lookupCachedData(spark.emptyDataFrame).isDefined === true)

// calling this twice does not throw any errors
unpersist()
}
}

Seq(
("UnpersistHandle", UnpersistHandle()),
("SilentUnpersistHandle", SilentUnpersistHandle())
).foreach { case (handleClass, unpersist) =>
test(s"$handleClass throws on setting DataFrame twice") {
unpersist.setDataFrame(spark.emptyDataFrame)
assert(intercept[IllegalStateException] {
unpersist.setDataFrame(spark.emptyDataFrame)
}.getMessage === s"DataFrame has been set already, it cannot be reused.")
}
}

test("UnpersistHandle throws on unpersist if no DataFrame is set") {
val unpersist = UnpersistHandle()
assert(intercept[IllegalStateException] { unpersist() }.getMessage === s"DataFrame has to be set first")
}

test("UnpersistHandle throws on unpersist with blocking if no DataFrame is set") {
val unpersist = UnpersistHandle()
assert(intercept[IllegalStateException] { unpersist(blocking = true) }.getMessage === s"DataFrame has to be set first")
}

test("SilentUnpersistHandle does not throw on unpersist if no DataFrame is set") {
val unpersist = SilentUnpersistHandle()
unpersist()
}

test("SilentUnpersistHandle does not throw on unpersist with blocking if no DataFrame is set") {
val unpersist = SilentUnpersistHandle()
unpersist(blocking = true)
}

test("backticks") {
assert(backticks("column") === "column")
assert(backticks("a.column") === "`a.column`")
@@ -224,13 +312,13 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
doTestWithRowNumbers { df => df.repartition(100) }($"id".desc)()
}

Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY).foreach { level =>
Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY, NONE).foreach { level =>
test(s"global row number with $level") {
doTestWithRowNumbers(storageLevel = level)($"id")()
}
}

Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY).foreach { level =>
Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY, NONE).foreach { level =>
test(s"global row number allows to unpersist with $level") {
val cacheManager = spark.sharedState.cacheManager
cacheManager.clearCache()
@@ -323,42 +411,6 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
assert(incorrectRowNumbers === 0)
}

test("UnpersistHandle does unpersit set DataFrame") {
val cacheManager = spark.sharedState.cacheManager
cacheManager.clearCache()
assert(cacheManager.isEmpty === true)

val unpersist = UnpersistHandle()
val df = spark.emptyDataFrame
assert(cacheManager.lookupCachedData(spark.emptyDataFrame).isDefined === false)

df.cache()
assert(cacheManager.lookupCachedData(spark.emptyDataFrame).isDefined === true)

unpersist.setDataFrame(df)
unpersist(blocking = true)
assert(cacheManager.lookupCachedData(spark.emptyDataFrame).isDefined === false)

// calling this twice does not throw any errors
unpersist()
}

test("UnpersistHandle throws on unpersist if no DataFrame is set") {
val unpersist = UnpersistHandle()
assert(intercept[IllegalStateException] { unpersist() }.getMessage === s"DataFrame has to be set first")
}

test("UnpersistHandle throws on unpersist with blocking if no DataFrame is set") {
val unpersist = UnpersistHandle()
assert(intercept[IllegalStateException] { unpersist(blocking = true) }.getMessage === s"DataFrame has to be set first")
}

test("UnpersistHandle throws on setting DataFrame twice") {
val unpersist = UnpersistHandle()
unpersist.setDataFrame(spark.emptyDataFrame)
assert(intercept[IllegalStateException] { unpersist.setDataFrame(spark.emptyDataFrame) }.getMessage === s"DataFrame has been set already. It cannot be reused once used with withRowNumbers.")
}

}

object SparkSuite {