Skip to content

Commit

Permalink
pass calendarinterval to timestampAddInterval
Browse files Browse the repository at this point in the history
  • Loading branch information
LinhongLiu committed Oct 31, 2019
1 parent 0e87e2d commit fb1591e
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2652,7 +2652,7 @@ object Sequence {
arr(i) = fromLong(t / scale)
i += 1
t = timestampAddInterval(
startMicros, i * stepMonths, i * stepDays, i * stepMicros, zoneId)
startMicros, new CalendarInterval(i * stepMonths, i * stepDays, i * stepMicros), zoneId)
}

// truncate array to the correct length
Expand Down Expand Up @@ -2715,7 +2715,10 @@ object Sequence {
| $arr[$i] = ($elemType) ($t / ${scale}L);
| $i += 1;
| $t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval(
| $startMicros, $i * $stepMonths, $i * $stepDays, $i * $stepMicros, $zid);
| $startMicros,
| new org.apache.spark.unsafe.types.CalendarInterval(
| $i * $stepMonths, $i * $stepDays, $i * $stepMicros),
| $zid);
| }
|
| if ($arr.length > $i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1089,15 +1089,14 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S

override def nullSafeEval(start: Any, interval: Any): Any = {
val itvl = interval.asInstanceOf[CalendarInterval]
DateTimeUtils.timestampAddInterval(
start.asInstanceOf[Long], itvl.months, itvl.days, itvl.microseconds, zoneId)
DateTimeUtils.timestampAddInterval(start.asInstanceOf[Long], itvl, zoneId)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddInterval($sd, $i.months, $i.days, $i.microseconds, $zid)"""
s"""$dtu.timestampAddInterval($sd, $i, $zid)"""
})
}
}
Expand Down Expand Up @@ -1204,15 +1203,14 @@ case class TimeSub(start: Expression, interval: Expression, timeZoneId: Option[S

override def nullSafeEval(start: Any, interval: Any): Any = {
val itvl = interval.asInstanceOf[CalendarInterval]
DateTimeUtils.timestampAddInterval(
start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.days, 0 - itvl.microseconds, zoneId)
DateTimeUtils.timestampAddInterval(start.asInstanceOf[Long], itvl.negate(), zoneId)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.days, 0 - $i.microseconds, $zid)"""
s"""$dtu.timestampAddInterval($sd, $i.negate(), $zid)"""
})
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -574,15 +574,13 @@ object DateTimeUtils {
*/
def timestampAddInterval(
start: SQLTimestamp,
months: Int,
days: Int,
microseconds: Long,
interval: CalendarInterval,
zoneId: ZoneId): SQLTimestamp = {
val resultTimestamp = microsToInstant(start)
.atZone(zoneId)
.plusMonths(months)
.plusDays(days)
.plus(microseconds, ChronoUnit.MICROS)
.plusMonths(interval.months)
.plusDays(interval.days)
.plus(interval.microseconds, ChronoUnit.MICROS)
instantToMicros(resultTimestamp.toInstant)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ object IntervalUtils {
val days = if (m.group(2) == null) {
0
} else {
toLongWithRange("day", m.group(3), 0, Integer.MAX_VALUE)
toLongWithRange("day", m.group(3), 0, Integer.MAX_VALUE).toInt
}
var hours: Long = 0L
var minutes: Long = 0L
Expand Down Expand Up @@ -238,7 +238,7 @@ object IntervalUtils {
micros = Math.addExact(micros, Math.multiplyExact(hours, MICROS_PER_HOUR))
micros = Math.addExact(micros, Math.multiplyExact(minutes, MICROS_PER_MINUTE))
micros = Math.addExact(micros, Math.multiplyExact(seconds, DateTimeUtils.MICROS_PER_SECOND))
new CalendarInterval(0, sign * days.toInt, sign * micros)
new CalendarInterval(0, sign * days, sign * micros)
} catch {
case e: Exception =>
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,13 +373,15 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers {
test("timestamp add months") {
val ts1 = date(1997, 2, 28, 10, 30, 0)
val ts2 = date(2000, 2, 28, 10, 30, 0, 123000)
assert(timestampAddInterval(ts1, 36, 0, 123000, defaultZoneId) === ts2)
assert(timestampAddInterval(ts1, new CalendarInterval(36, 0, 123000), defaultZoneId) === ts2)

val ts3 = date(1997, 2, 27, 16, 0, 0, 0, TimeZonePST)
val ts4 = date(2000, 2, 27, 16, 0, 0, 123000, TimeZonePST)
val ts5 = date(2000, 2, 28, 0, 0, 0, 123000, TimeZoneGMT)
assert(timestampAddInterval(ts3, 36, 0, 123000, TimeZonePST.toZoneId) === ts4)
assert(timestampAddInterval(ts3, 36, 0, 123000, TimeZoneGMT.toZoneId) === ts5)
assert(timestampAddInterval(
ts3, new CalendarInterval(36, 0, 123000), TimeZonePST.toZoneId) === ts4)
assert(timestampAddInterval(
ts3, new CalendarInterval(36, 0, 123000), TimeZoneGMT.toZoneId) === ts5)
}

test("timestamp add days") {
Expand All @@ -396,16 +398,22 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers {

// transit from Pacific Standard Time to Pacific Daylight Time
assert(timestampAddInterval(
ts1, 0, 0, 23 * CalendarInterval.MICROS_PER_HOUR, TimeZonePST.toZoneId) === ts2)
assert(timestampAddInterval(ts1, 0, 1, 0, TimeZonePST.toZoneId) === ts2)
ts1,
new CalendarInterval(0, 0, 23 * CalendarInterval.MICROS_PER_HOUR),
TimeZonePST.toZoneId) === ts2)
assert(timestampAddInterval(ts1, new CalendarInterval(0, 1, 0), TimeZonePST.toZoneId) === ts2)
// just a normal day
assert(timestampAddInterval(
ts3, 0, 0, 24 * CalendarInterval.MICROS_PER_HOUR, TimeZonePST.toZoneId) === ts4)
assert(timestampAddInterval(ts3, 0, 1, 0, TimeZonePST.toZoneId) === ts4)
ts3,
new CalendarInterval(0, 0, 24 * CalendarInterval.MICROS_PER_HOUR),
TimeZonePST.toZoneId) === ts4)
assert(timestampAddInterval(ts3, new CalendarInterval(0, 1, 0), TimeZonePST.toZoneId) === ts4)
// transit from Pacific Daylight Time to Pacific Standard Time
assert(timestampAddInterval(
ts5, 0, 0, 25 * CalendarInterval.MICROS_PER_HOUR, TimeZonePST.toZoneId) === ts6)
assert(timestampAddInterval(ts5, 0, 1, 0, TimeZonePST.toZoneId) === ts6)
ts5,
new CalendarInterval(0, 0, 25 * CalendarInterval.MICROS_PER_HOUR),
TimeZonePST.toZoneId) === ts6)
assert(timestampAddInterval(ts5, new CalendarInterval(0, 1, 0), TimeZonePST.toZoneId) === ts6)
}

test("monthsBetween") {
Expand Down

0 comments on commit fb1591e

Please sign in to comment.