[SPARK-49836][SQL][SS] Fix possibly broken query when window is provided to window/session_window fn

### What changes were proposed in this pull request?

This PR fixes a correctness issue where operators are lost during analysis. It happens when a window is provided to the window()/session_window() function as its time column.

The rules `TimeWindowing` and `SessionWindowing` are responsible for resolving the time window functions. When the window function takes a `window` as its parameter (time column), in other words, when a time window is built from another time window, the rule wraps the window with the `WindowTime` function so that the rule `ResolveWindowTime` can resolve it further. (`TimeWindowing`/`SessionWindowing` then resolve the result of `ResolveWindowTime` again.)
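For illustration, this is roughly the shape of "building a time window from a time window" in the DataFrame API, on Spark versions that support passing the window struct as a time column. This is a minimal sketch with made-up table/column names, not code from this PR:

```scala
// Sketch only: the struct column produced by window() is fed back into
// session_window() as its time column, which is the case the rules above handle.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, session_window, sum, window}

object ChainedWindowSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("chained-windows").getOrCreate()
    import spark.implicits._

    val clicks = Seq(("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"))
      .toDF("eventTime", "userId")
      .selectExpr("CAST(eventTime AS TIMESTAMP) AS eventTime", "userId")

    // First aggregation: 1-minute tumbling windows.
    val perMinute = clicks
      .groupBy(window($"eventTime", "1 minute"), $"userId")
      .agg(count("*").as("numClicks"))

    // Second aggregation: the "window" struct from the first aggregation is the
    // time column of the session window.
    val perSession = perMinute
      .groupBy(session_window($"window", "5 minutes"), $"userId")
      .agg(sum($"numClicks").as("numClicks"))

    perSession.show(truncate = false)
    spark.stop()
  }
}
```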

The issue is that the rules use `return` for the branch above, intending an early return because the other branch is much longer. Unfortunately, this does not work as intended: the intent is only to leave the current local scope (essentially the end of the enclosing curly brace), but the `return` appears to break the loop of execution on the "outer" side.
(I haven't debugged further, but it is clearly not working as intended.)

Quoting from the Scala documentation:

> Nonlocal returns are implemented by throwing and catching scala.runtime.NonLocalReturnException-s.

It is not entirely clear where the NonLocalReturnException is caught in the call stack; it may exit execution for a much broader scope (context) than expected. Non-local returns are also deprecated as of Scala 3.2 and are likely to be removed in the future.

https://dotty.epfl.ch/docs/reference/dropped-features/nonlocal-returns.html
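To make the mechanism concrete, here is a tiny standalone sketch (not Spark code) of a non-local return: the `return` inside the closure is compiled into a control-flow throwable that unwinds all the way to the enclosing method boundary, so everything between the closure and that boundary is skipped.

```scala
// Minimal sketch of a non-local return. The `return` below does not merely leave
// the lambda passed to foreach; it throws a control throwable (NonLocalReturnControl
// in Scala 2) that unwinds out of foreach and exits classify() itself.
import scala.util.control.NonFatal

object NonLocalReturnSketch {
  def classify(xs: Seq[Int]): String = {
    try {
      xs.foreach { x =>
        if (x < 0) return "found a negative" // non-local return: exits classify()
      }
      "all non-negative"
    } catch {
      // NonFatal deliberately does not match control throwables, so the return still
      // escapes here; a bare `case _: Throwable` in between would swallow it instead.
      case NonFatal(e) => s"unexpected failure: ${e.getMessage}"
    }
  }

  def main(args: Array[String]): Unit = {
    println(classify(Seq(1, -2, 3))) // prints "found a negative"
  }
}
```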

Interestingly, it does not break every query that chains time window aggregations. Spark already has several tests using the DataFrame API, and they have not failed. The reproducer from the community report uses a SQL statement, where each aggregation is expressed as a subquery.

This PR fixes the rules to NOT use an early return and instead use a large if-else block.
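To show why the early return loses operators, here is a simplified, self-contained sketch of the broken pattern and the if-else shape this PR switches to. The `Node`/`transformUp` stand-ins are hypothetical, not Catalyst itself:

```scala
// Simplified stand-ins for a plan tree and a bottom-up transform; names are made up.
object EarlyReturnPitfall {
  case class Node(name: String, children: Seq[Node] = Nil)

  // Visit children first, then apply the rule to the rebuilt node.
  def transformUp(node: Node)(rule: Node => Node): Node = {
    val withNewChildren = node.copy(children = node.children.map(transformUp(_)(rule)))
    rule(withNewChildren)
  }

  // BROKEN: the `return` is non-local. When the rule fires on the inner "window" node,
  // it exits rewriteBroken() with just that node, silently dropping the enclosing
  // "project" and "filter" operators, analogous to the bug described above.
  def rewriteBroken(root: Node): Node =
    transformUp(root) { n =>
      if (n.name == "window") return n.copy(name = "windowTime")
      n
    }

  // FIXED: express the same branch with a plain if/else inside the closure.
  def rewriteFixed(root: Node): Node =
    transformUp(root) { n =>
      if (n.name == "window") n.copy(name = "windowTime") else n
    }

  def main(args: Array[String]): Unit = {
    val plan = Node("project", Seq(Node("filter", Seq(Node("window")))))
    println(rewriteBroken(plan)) // Node(windowTime,List()): parent operators are lost
    println(rewriteFixed(plan))  // full tree preserved, only the leaf is rewritten
  }
}
```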

### Why are the changes needed?

Described above.

### Does this PR introduce _any_ user-facing change?

Yes, this fixes possible query breakage. The impacted workloads are probably not huge, since chained time window aggregation is an advanced usage, and the bug does not break every query that uses it.

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#48309 from HeartSaVioR/SPARK-49836.

Lead-authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Co-authored-by: Andrzej Zera <andrzejzera@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
HeartSaVioR and andrzejzera committed Oct 4, 2024
1 parent 0c653db commit d8c04cf
Showing 3 changed files with 232 additions and 127 deletions.
@@ -87,85 +87,86 @@ object TimeWindowing extends Rule[LogicalPlan] {
       val window = windowExpressions.head
 
+      // time window is provided as time column of window function, replace it with WindowTime
       if (StructType.acceptsType(window.timeColumn.dataType)) {
-        return p.transformExpressions {
+        p.transformExpressions {
           case t: TimeWindow => t.copy(timeColumn = WindowTime(window.timeColumn))
         }
-      }
-
-      val metadata = window.timeColumn match {
-        case a: Attribute => a.metadata
-        case _ => Metadata.empty
-      }
-
-      val newMetadata = new MetadataBuilder()
-        .withMetadata(metadata)
-        .putBoolean(TimeWindow.marker, true)
-        .build()
-
-      def getWindow(i: Int, dataType: DataType): Expression = {
-        val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
-        val remainder = (timestamp - window.startTime) % window.slideDuration
-        val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
-          remainder + window.slideDuration)), Some(remainder))
-        val windowStart = lastStart - i * window.slideDuration
-        val windowEnd = windowStart + window.windowDuration
-
-        // We make sure value fields are nullable since the dataType of TimeWindow defines them
-        // as nullable.
-        CreateNamedStruct(
-          Literal(WINDOW_START) ::
-            PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
-            Literal(WINDOW_END) ::
-            PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
-            Nil)
-      }
-
-      val windowAttr = AttributeReference(
-        WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()
-
-      if (window.windowDuration == window.slideDuration) {
-        val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
-          exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))
-
-        val replacedPlan = p transformExpressions {
-          case t: TimeWindow => windowAttr
-        }
-
-        // For backwards compatibility we add a filter to filter out nulls
-        val filterExpr = IsNotNull(window.timeColumn)
-
-        replacedPlan.withNewChildren(
-          Project(windowStruct +: child.output,
-            Filter(filterExpr, child)) :: Nil)
-      } else {
-        val overlappingWindows =
-          math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
-        val windows =
-          Seq.tabulate(overlappingWindows)(i =>
-            getWindow(i, window.timeColumn.dataType))
-
-        val projections = windows.map(_ +: child.output)
-
-        // When the condition windowDuration % slideDuration = 0 is fulfilled,
-        // the estimation of the number of windows becomes exact one,
-        // which means all produced windows are valid.
-        val filterExpr =
-          if (window.windowDuration % window.slideDuration == 0) {
-            IsNotNull(window.timeColumn)
-          } else {
-            window.timeColumn >= windowAttr.getField(WINDOW_START) &&
-              window.timeColumn < windowAttr.getField(WINDOW_END)
-          }
-
-        val substitutedPlan = Filter(filterExpr,
-          Expand(projections, windowAttr +: child.output, child))
-
-        val renamedPlan = p transformExpressions {
-          case t: TimeWindow => windowAttr
-        }
-
-        renamedPlan.withNewChildren(substitutedPlan :: Nil)
-      }
+      } else {
+        val metadata = window.timeColumn match {
+          case a: Attribute => a.metadata
+          case _ => Metadata.empty
+        }
+
+        val newMetadata = new MetadataBuilder()
+          .withMetadata(metadata)
+          .putBoolean(TimeWindow.marker, true)
+          .build()
+
+        def getWindow(i: Int, dataType: DataType): Expression = {
+          val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
+          val remainder = (timestamp - window.startTime) % window.slideDuration
+          val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
+            remainder + window.slideDuration)), Some(remainder))
+          val windowStart = lastStart - i * window.slideDuration
+          val windowEnd = windowStart + window.windowDuration
+
+          // We make sure value fields are nullable since the dataType of TimeWindow defines them
+          // as nullable.
+          CreateNamedStruct(
+            Literal(WINDOW_START) ::
+              PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
+              Literal(WINDOW_END) ::
+              PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
+              Nil)
+        }
+
+        val windowAttr = AttributeReference(
+          WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()
+
+        if (window.windowDuration == window.slideDuration) {
+          val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
+            exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))
+
+          val replacedPlan = p transformExpressions {
+            case t: TimeWindow => windowAttr
+          }
+
+          // For backwards compatibility we add a filter to filter out nulls
+          val filterExpr = IsNotNull(window.timeColumn)
+
+          replacedPlan.withNewChildren(
+            Project(windowStruct +: child.output,
+              Filter(filterExpr, child)) :: Nil)
+        } else {
+          val overlappingWindows =
+            math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
+          val windows =
+            Seq.tabulate(overlappingWindows)(i =>
+              getWindow(i, window.timeColumn.dataType))
+
+          val projections = windows.map(_ +: child.output)
+
+          // When the condition windowDuration % slideDuration = 0 is fulfilled,
+          // the estimation of the number of windows becomes exact one,
+          // which means all produced windows are valid.
+          val filterExpr =
+            if (window.windowDuration % window.slideDuration == 0) {
+              IsNotNull(window.timeColumn)
+            } else {
+              window.timeColumn >= windowAttr.getField(WINDOW_START) &&
+                window.timeColumn < windowAttr.getField(WINDOW_END)
+            }
+
+          val substitutedPlan = Filter(filterExpr,
+            Expand(projections, windowAttr +: child.output, child))
+
+          val renamedPlan = p transformExpressions {
+            case t: TimeWindow => windowAttr
+          }
+
+          renamedPlan.withNewChildren(substitutedPlan :: Nil)
+        }
+      }
     } else if (numWindowExpr > 1) {
       throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
@@ -210,74 +211,74 @@ object SessionWindowing extends Rule[LogicalPlan] {
       val session = sessionExpressions.head
 
       if (StructType.acceptsType(session.timeColumn.dataType)) {
-        return p transformExpressions {
+        p transformExpressions {
           case t: SessionWindow => t.copy(timeColumn = WindowTime(session.timeColumn))
         }
-      }
-
-      val metadata = session.timeColumn match {
-        case a: Attribute => a.metadata
-        case _ => Metadata.empty
-      }
-
-      val newMetadata = new MetadataBuilder()
-        .withMetadata(metadata)
-        .putBoolean(SessionWindow.marker, true)
-        .build()
-
-      val sessionAttr = AttributeReference(
-        SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
-
-      val sessionStart =
-        PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
-      val gapDuration = session.gapDuration match {
-        case expr if expr.dataType == CalendarIntervalType =>
-          expr
-        case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
-          Cast(expr, CalendarIntervalType)
-        case other =>
-          throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
-      }
-      val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
-        session.timeColumn.dataType, LongType)
-
-      // We make sure value fields are nullable since the dataType of SessionWindow defines them
-      // as nullable.
-      val literalSessionStruct = CreateNamedStruct(
-        Literal(SESSION_START) ::
-          PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
-            .castNullable() ::
-          Literal(SESSION_END) ::
-          PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
-            .castNullable() ::
-          Nil)
-
-      val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
-        exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
-
-      val replacedPlan = p transformExpressions {
-        case s: SessionWindow => sessionAttr
-      }
-
-      val filterByTimeRange = if (gapDuration.foldable) {
-        val interval = gapDuration.eval().asInstanceOf[CalendarInterval]
-        interval == null || interval.months + interval.days + interval.microseconds <= 0
-      } else {
-        true
-      }
-
-      // As same as tumbling window, we add a filter to filter out nulls.
-      // And we also filter out events with negative or zero or invalid gap duration.
-      val filterExpr = if (filterByTimeRange) {
-        IsNotNull(session.timeColumn) &&
-          (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
-      } else {
-        IsNotNull(session.timeColumn)
-      }
-
-      replacedPlan.withNewChildren(
-        Filter(filterExpr,
-          Project(sessionStruct +: child.output, child)) :: Nil)
+      } else {
+        val metadata = session.timeColumn match {
+          case a: Attribute => a.metadata
+          case _ => Metadata.empty
+        }
+
+        val newMetadata = new MetadataBuilder()
+          .withMetadata(metadata)
+          .putBoolean(SessionWindow.marker, true)
+          .build()
+
+        val sessionAttr = AttributeReference(
+          SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
+
+        val sessionStart =
+          PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
+        val gapDuration = session.gapDuration match {
+          case expr if expr.dataType == CalendarIntervalType =>
+            expr
+          case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
+            Cast(expr, CalendarIntervalType)
+          case other =>
+            throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
+        }
+        val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
+          session.timeColumn.dataType, LongType)
+
+        // We make sure value fields are nullable since the dataType of SessionWindow defines them
+        // as nullable.
+        val literalSessionStruct = CreateNamedStruct(
+          Literal(SESSION_START) ::
+            PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
+              .castNullable() ::
+            Literal(SESSION_END) ::
+            PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
+              .castNullable() ::
+            Nil)
+
+        val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
+          exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
+
+        val replacedPlan = p transformExpressions {
+          case s: SessionWindow => sessionAttr
+        }
+
+        val filterByTimeRange = if (gapDuration.foldable) {
+          val interval = gapDuration.eval().asInstanceOf[CalendarInterval]
+          interval == null || interval.months + interval.days + interval.microseconds <= 0
+        } else {
+          true
+        }
+
+        // As same as tumbling window, we add a filter to filter out nulls.
+        // And we also filter out events with negative or zero or invalid gap duration.
+        val filterExpr = if (filterByTimeRange) {
+          IsNotNull(session.timeColumn) &&
+            (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
+        } else {
+          IsNotNull(session.timeColumn)
+        }
+
+        replacedPlan.withNewChildren(
+          Filter(filterExpr,
+            Project(sessionStruct +: child.output, child)) :: Nil)
+      }
     } else if (numWindowExpr > 1) {
       throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
     } else {
@@ -547,4 +547,55 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
       }
     }
   }
+
+  test("SPARK-49836 using window fn with window as parameter should preserve parent operator") {
+    withTempView("clicks") {
+      val df = Seq(
+        // small window: [00:00, 01:00), user1, 2
+        ("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"),
+        // small window: [01:00, 02:00), user2, 2
+        ("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"),
+        // small window: [03:00, 04:00), user1, 1
+        ("2024-09-30 00:03:30", "user1"),
+        // small window: [11:00, 12:00), user1, 3
+        ("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"),
+        ("2024-09-30 00:11:45", "user1")
+      ).toDF("eventTime", "userId")
+
+      // session window: (01:00, 09:00), user1, 3 / (02:00, 07:00), user2, 2 /
+      // (12:00, 12:05), user1, 3
+
+      df.createOrReplaceTempView("clicks")
+
+      val aggregatedData = spark.sql(
+        """
+          |SELECT
+          |  userId,
+          |  avg(cpu_large.numClicks) AS clicksPerSession
+          |FROM
+          |(
+          |  SELECT
+          |    session_window(small_window, '5 minutes') AS session,
+          |    userId,
+          |    sum(numClicks) AS numClicks
+          |  FROM
+          |  (
+          |    SELECT
+          |      window(eventTime, '1 minute') AS small_window,
+          |      userId,
+          |      count(*) AS numClicks
+          |    FROM clicks
+          |    GROUP BY window, userId
+          |  ) cpu_small
+          |  GROUP BY session_window, userId
+          |) cpu_large
+          |GROUP BY userId
+          |""".stripMargin)
+
+      checkAnswer(
+        aggregatedData,
+        Seq(Row("user1", 3), Row("user2", 2))
+      )
+    }
+  }
 }
