diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index 4052ccd64965d..18c85999312df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import java.time.{Instant, LocalDateTime}
+import java.time.{Instant, LocalDateTime, ZoneId}
 
 import org.apache.spark.sql.catalyst.CurrentUserContext
 import org.apache.spark.sql.catalyst.expressions._
@@ -79,6 +79,8 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
     val currentTimestampMicros = instantToMicros(instant)
     val currentTime = Literal.create(currentTimestampMicros, TimestampType)
     val timezone = Literal.create(conf.sessionLocalTimeZone, StringType)
+    val currentDates = collection.mutable.HashMap.empty[ZoneId, Literal]
+    val localTimestamps = collection.mutable.HashMap.empty[ZoneId, Literal]
 
     def transformCondition(treePatternbits: TreePatternBits): Boolean = {
       treePatternbits.containsPattern(CURRENT_LIKE)
@@ -88,12 +90,17 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
       case subQuery =>
         subQuery.transformAllExpressionsWithPruning(transformCondition) {
           case cd: CurrentDate =>
-            Literal.create(DateTimeUtils.microsToDays(currentTimestampMicros, cd.zoneId), DateType)
+            currentDates.getOrElseUpdate(cd.zoneId, {
+              Literal.create(
+                DateTimeUtils.microsToDays(currentTimestampMicros, cd.zoneId), DateType)
+            })
           case CurrentTimestamp() | Now() => currentTime
           case CurrentTimeZone() => timezone
           case localTimestamp: LocalTimestamp =>
-            val asDateTime = LocalDateTime.ofInstant(instant, localTimestamp.zoneId)
-            Literal.create(localDateTimeToMicros(asDateTime), TimestampNTZType)
+            localTimestamps.getOrElseUpdate(localTimestamp.zoneId, {
+              val asDateTime = LocalDateTime.ofInstant(instant, localTimestamp.zoneId)
+              Literal.create(localDateTimeToMicros(asDateTime), TimestampNTZType)
+            })
         }
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
index 8b76cc383c5a6..447d77855fb3f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
@@ -23,7 +23,7 @@ import scala.concurrent.duration._
 import scala.jdk.CollectionConverters.MapHasAsScala
 
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, InSubquery, ListQuery, Literal, LocalTimestamp, Now}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Expression, InSubquery, ListQuery, Literal, LocalTimestamp, Now}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -135,6 +135,34 @@ class ComputeCurrentTimeSuite extends PlanTest {
     assert(offsetsFromQuarterHour.size == 1)
   }
 
+  test("No duplicate literals") {
+    def checkLiterals(f: (String) => Expression, expected: Int): Unit = {
+      val timestamps = ZoneId.SHORT_IDS.asScala.flatMap { case (zoneId, _) =>
+        // Request each timestamp multiple times.
+        (1 to 5).map { _ => Alias(f(zoneId), zoneId)() }
+      }.toSeq
+
+      val input = Project(timestamps, LocalRelation())
+      val plan = Optimize.execute(input).asInstanceOf[Project]
+
+      val uniqueLiteralObjectIds = new scala.collection.mutable.HashSet[Int]
+      plan.transformWithSubqueries { case subQuery =>
+        subQuery.transformAllExpressions { case literal: Literal =>
+          uniqueLiteralObjectIds += System.identityHashCode(literal)
+          literal
+        }
+      }
+
+      assert(expected === uniqueLiteralObjectIds.size)
+    }
+
+    val numTimezones = ZoneId.SHORT_IDS.size
+    checkLiterals({ _: String => CurrentTimestamp() }, 1)
+    checkLiterals({ zoneId: String => LocalTimestamp(Some(zoneId)) }, numTimezones)
+    checkLiterals({ _: String => Now() }, 1)
+    checkLiterals({ zoneId: String => CurrentDate(Some(zoneId)) }, numTimezones)
+  }
+
   private def literals[T](plan: LogicalPlan): scala.collection.mutable.ArrayBuffer[T] = {
     val literals = new scala.collection.mutable.ArrayBuffer[T]
     plan.transformWithSubqueries { case subQuery =>
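
Note: below is a minimal standalone sketch, not part of the patch, of the per-zone memoization pattern the rule now uses: a mutable HashMap keyed by ZoneId plus getOrElseUpdate, so a value for a given time zone is computed once and the same object is returned for every later occurrence. It uses only plain java.time and Scala collection APIs; the object and method names are hypothetical.

// PerZoneMemoSketch.scala -- illustrative only, not part of the diff above.
import java.time.{Instant, LocalDateTime, ZoneId}
import scala.collection.mutable

object PerZoneMemoSketch {
  def main(args: Array[String]): Unit = {
    // One fixed instant, mirroring how the rule evaluates "now" once per query.
    val instant = Instant.now()
    val cache = mutable.HashMap.empty[ZoneId, LocalDateTime]

    // getOrElseUpdate evaluates its by-name argument only on the first lookup
    // for a given zone; later lookups return the cached object itself.
    def localTimeFor(zone: ZoneId): LocalDateTime =
      cache.getOrElseUpdate(zone, LocalDateTime.ofInstant(instant, zone))

    val first = localTimeFor(ZoneId.of("UTC"))
    val second = localTimeFor(ZoneId.of("UTC"))
    assert(first eq second) // same instance, computed once
    println(s"cached zones: ${cache.size}")
  }
}

This is the same reference-identity property the new "No duplicate literals" test checks with System.identityHashCode: repeated CurrentDate or LocalTimestamp expressions for one time zone should collapse to a single Literal object.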