# [SPARK-45660] Re-use Literal objects in ComputeCurrentTime rule
### What changes were proposed in this pull request?

The ComputeCurrentTime optimizer rule replaces the current-time expressions of a query with timestamp Literals that are computed once per query. For CurrentTimestamp this Literal object is shared, but for CurrentDate and LocalTimestamp a new, semantically equal Literal is created for every instance. This can waste a significant amount of memory when a query contains many such expressions.

This PR adds two maps, keyed by time zone, that cache the date and local-timestamp literals so that each is created at most once per zone.
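
For illustration, here is a minimal, self-contained sketch of the caching pattern; the `Literal` case class below is a hypothetical stand-in for Catalyst's `Literal`:

```scala
import java.time.ZoneId
import scala.collection.mutable

object LiteralCacheSketch extends App {
  // Hypothetical stand-in for Catalyst's Literal, for illustration only.
  case class Literal(value: Any)

  val currentDates = mutable.HashMap.empty[ZoneId, Literal]

  // getOrElseUpdate evaluates its default only on a cache miss, so every
  // expression that shares a zone id receives the same Literal object.
  def currentDateLiteral(zoneId: ZoneId, micros: Long): Literal =
    currentDates.getOrElseUpdate(zoneId, Literal(micros))

  val a = currentDateLiteral(ZoneId.of("UTC"), 0L)
  val b = currentDateLiteral(ZoneId.of("UTC"), 0L)
  assert(a eq b)  // the same object is re-used, not a semantically equal copy
}
```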

### Why are the changes needed?

A query that contains many semantically equal literals could otherwise consume unnecessarily large amounts of memory.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added a new unit test to `ComputeCurrentTimeSuite` that verifies literal objects are re-used (see the diff below).

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#43524 from olaky/unique-timestamp-replacement-literals.

Authored-by: Ole Sasse <ole.sasse@databricks.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
olaky authored and MaxGekk committed Oct 25, 2023
1 parent 94607dd commit 3444537
Showing 2 changed files with 40 additions and 5 deletions.
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import java.time.{Instant, LocalDateTime}
+import java.time.{Instant, LocalDateTime, ZoneId}
 
 import org.apache.spark.sql.catalyst.CurrentUserContext
 import org.apache.spark.sql.catalyst.expressions._
@@ -79,6 +79,8 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
     val currentTimestampMicros = instantToMicros(instant)
     val currentTime = Literal.create(currentTimestampMicros, TimestampType)
     val timezone = Literal.create(conf.sessionLocalTimeZone, StringType)
+    val currentDates = collection.mutable.HashMap.empty[ZoneId, Literal]
+    val localTimestamps = collection.mutable.HashMap.empty[ZoneId, Literal]
 
     def transformCondition(treePatternbits: TreePatternBits): Boolean = {
       treePatternbits.containsPattern(CURRENT_LIKE)
@@ -88,12 +90,17 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
       case subQuery =>
         subQuery.transformAllExpressionsWithPruning(transformCondition) {
           case cd: CurrentDate =>
-            Literal.create(DateTimeUtils.microsToDays(currentTimestampMicros, cd.zoneId), DateType)
+            currentDates.getOrElseUpdate(cd.zoneId, {
+              Literal.create(
+                DateTimeUtils.microsToDays(currentTimestampMicros, cd.zoneId), DateType)
+            })
           case CurrentTimestamp() | Now() => currentTime
           case CurrentTimeZone() => timezone
           case localTimestamp: LocalTimestamp =>
-            val asDateTime = LocalDateTime.ofInstant(instant, localTimestamp.zoneId)
-            Literal.create(localDateTimeToMicros(asDateTime), TimestampNTZType)
+            localTimestamps.getOrElseUpdate(localTimestamp.zoneId, {
+              val asDateTime = LocalDateTime.ofInstant(instant, localTimestamp.zoneId)
+              Literal.create(localDateTimeToMicros(asDateTime), TimestampNTZType)
+            })
         }
     }
   }
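The caches are keyed by `ZoneId` because, unlike `CurrentTimestamp`, the values of `CurrentDate` and `LocalTimestamp` depend on the expression's time zone, so a literal can only be shared among expressions with the same zone. A small sketch with plain `java.time` (the instant and zones are arbitrary examples):

```scala
import java.time.{Instant, LocalDateTime, ZoneId}

object ZoneKeyedSketch extends App {
  // An arbitrary instant, chosen for illustration.
  val instant = Instant.parse("2023-10-25T23:30:00Z")

  // The same instant maps to different wall-clock values in different zones,
  // so date and local-timestamp literals cannot be shared across zones.
  val utc   = LocalDateTime.ofInstant(instant, ZoneId.of("UTC"))
  val tokyo = LocalDateTime.ofInstant(instant, ZoneId.of("Asia/Tokyo"))

  println(utc)    // 2023-10-25T23:30
  println(tokyo)  // 2023-10-26T08:30 -- even the calendar date differs
}
```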
@@ -23,7 +23,7 @@ import scala.concurrent.duration._
 import scala.jdk.CollectionConverters.MapHasAsScala
 
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, InSubquery, ListQuery, Literal, LocalTimestamp, Now}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Expression, InSubquery, ListQuery, Literal, LocalTimestamp, Now}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -135,6 +135,34 @@ class ComputeCurrentTimeSuite extends PlanTest {
     assert(offsetsFromQuarterHour.size == 1)
   }
 
+  test("No duplicate literals") {
+    def checkLiterals(f: (String) => Expression, expected: Int): Unit = {
+      val timestamps = ZoneId.SHORT_IDS.asScala.flatMap { case (zoneId, _) =>
+        // Request each timestamp multiple times.
+        (1 to 5).map { _ => Alias(f(zoneId), zoneId)() }
+      }.toSeq
+
+      val input = Project(timestamps, LocalRelation())
+      val plan = Optimize.execute(input).asInstanceOf[Project]
+
+      val uniqueLiteralObjectIds = new scala.collection.mutable.HashSet[Int]
+      plan.transformWithSubqueries { case subQuery =>
+        subQuery.transformAllExpressions { case literal: Literal =>
+          uniqueLiteralObjectIds += System.identityHashCode(literal)
+          literal
+        }
+      }
+
+      assert(expected === uniqueLiteralObjectIds.size)
+    }
+
+    val numTimezones = ZoneId.SHORT_IDS.size
+    checkLiterals({ _: String => CurrentTimestamp() }, 1)
+    checkLiterals({ zoneId: String => LocalTimestamp(Some(zoneId)) }, numTimezones)
+    checkLiterals({ _: String => Now() }, 1)
+    checkLiterals({ zoneId: String => CurrentDate(Some(zoneId)) }, numTimezones)
+  }
+
   private def literals[T](plan: LogicalPlan): scala.collection.mutable.ArrayBuffer[T] = {
     val literals = new scala.collection.mutable.ArrayBuffer[T]
     plan.transformWithSubqueries { case subQuery =>
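The test counts distinct `System.identityHashCode` values instead of comparing literals with `==`, because re-used and freshly created literals are semantically equal either way; only reference identity reveals whether one object was shared. A minimal sketch of the distinction, using a hypothetical `Lit` class:

```scala
object IdentitySketch extends App {
  // Hypothetical value class, for illustration only.
  case class Lit(value: Int)

  val x = Lit(42)
  val y = Lit(42)

  assert(x == y)     // semantically equal: case-class equality compares values
  assert(!(x eq y))  // but they are two distinct objects on the heap

  // Counting identity hash codes therefore approximates counting objects
  // (identity hashes can collide, but rarely do in practice).
  val ids = Set(System.identityHashCode(x), System.identityHashCode(y))
  println(ids.size)  // 2: two distinct objects
}
```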
