diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
index b161651c4e6a3..6fa7ee0c38185 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
@@ -36,10 +36,23 @@ case class EventTimeStats(var max: Long, var min: Long, var avg: Double, var cou
   }
 
   def merge(that: EventTimeStats): Unit = {
-    this.max = math.max(this.max, that.max)
-    this.min = math.min(this.min, that.min)
-    this.count += that.count
-    this.avg += (that.avg - this.avg) * that.count / this.count
+    if (that.count == 0) {
+      // Nothing to fold in: `that` has no samples. Skipping the update also keeps `avg` from
+      // becoming NaN (0.0 / 0) in the incremental formula when both sides are empty.
+    } else if (this.count == 0) {
+      // `this` has no samples yet: take the statistics of `that` as-is.
+      this.max = that.max
+      this.min = that.min
+      this.count = that.count
+      this.avg = that.avg
+    } else {
+      // Both sides have samples: combine the extremes, then shift `avg` towards `that.avg`.
+      // `count` is bumped first because the shift is weighted by that.count / combined count.
+      this.max = math.max(this.max, that.max)
+      this.min = math.min(this.min, that.min)
+      this.count += that.count
+      this.avg += (that.avg - this.avg) * that.count / this.count
+    }
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
index d6bef9ce07379..a51f0869ffa4a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
@@ -38,9 +38,9 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche
     sqlContext.streams.active.foreach(_.stop())
   }
 
-  test("EventTimeStats") {
-    val epsilon = 10E-6
+  private val epsilon = 10E-6
 
+  test("EventTimeStats") {
     val stats = EventTimeStats(max = 100, min = 10, avg = 20.0, count = 5)
     stats.add(80L)
     stats.max should be (100)
@@ -57,7 +57,6 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche
   }
 
   test("EventTimeStats: avg on large values") {
-    val epsilon = 10E-6
     val largeValue = 10000000000L // 10B
     // Make sure `largeValue` will cause overflow if we use a Long sum to calc avg.
     assert(largeValue * largeValue != BigInt(largeValue) * BigInt(largeValue))
@@ -75,6 +74,33 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche
     stats.avg should be ((largeValue + 0.5) +- epsilon)
   }
 
+  test("EventTimeStats: zero merge zero") {
+    val stats = EventTimeStats.zero
+    val stats2 = EventTimeStats.zero
+    stats.merge(stats2)
+    stats should be (EventTimeStats.zero)
+  }
+
+  test("EventTimeStats: non-zero merge zero") {
+    val stats = EventTimeStats(max = 10, min = 1, avg = 5.0, count = 3)
+    val stats2 = EventTimeStats.zero
+    stats.merge(stats2)
+    stats.max should be (10L)
+    stats.min should be (1L)
+    stats.avg should be (5.0 +- epsilon)
+    stats.count should be (3L)
+  }
+
+  test("EventTimeStats: zero merge non-zero") {
+    val stats = EventTimeStats.zero
+    val stats2 = EventTimeStats(max = 10, min = 1, avg = 5.0, count = 3)
+    stats.merge(stats2)
+    stats.max should be (10L)
+    stats.min should be (1L)
+    stats.avg should be (5.0 +- epsilon)
+    stats.count should be (3L)
+  }
+
   test("error on bad column") {
     val inputData = MemoryStream[Int].toDF()
     val e = intercept[AnalysisException] {