From dfb880951a8de55c587c1bf8b696df50eae6e68a Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Sun, 3 Feb 2019 21:19:35 +0800
Subject: [PATCH] [SPARK-26818][ML] Make MLEvents JSON ser/de safe

## What changes were proposed in this pull request?

Currently, this does not appear to cause any practical problem (if I did not misread the code). I see one place where JSON-formatted events are used:

https://github.com/apache/spark/blob/ec506bd30c2ca324c12c9ec811764081c2eb8c42/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala#L148

That is okay because it only logs a message when the exception is ignorable:

https://github.com/apache/spark/blob/9690eba16efe6d25261934d8b73a221972b684f3/core/src/main/scala/org/apache/spark/util/ListenerBus.scala#L111

Still, it is best to stay safe - I don't want this unstable, experimental feature to break anything in any case. This change also disables `logEvent` in `SparkListenerEvent` for the same reason.

This also matches the SQL execution events side:

https://github.com/apache/spark/blob/ca545f79410a464ef24e3986fac225f53bb2ef02/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala#L41-L57

to make ML events JSON ser/de safe.

## How was this patch tested?

Manually tested, and unit tests were added.

Closes #23728 from HyukjinKwon/SPARK-26818.

Authored-by: Hyukjin Kwon
Signed-off-by: Hyukjin Kwon
---
 .../scala/org/apache/spark/ml/events.scala    |  81 ++++++++++---
 .../org/apache/spark/ml/MLEventsSuite.scala   | 112 ++++++++++++++----
 2 files changed, 155 insertions(+), 38 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/events.scala b/mllib/src/main/scala/org/apache/spark/ml/events.scala
index c51600fcca466..dc4be4dd9efda 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/events.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/events.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml
 
+import com.fasterxml.jackson.annotation.JsonIgnore
+
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Unstable
 import org.apache.spark.internal.Logging
@@ -29,53 +31,84 @@ import org.apache.spark.sql.{DataFrame, Dataset}
  * after each operation (the event should document this).
  *
  * @note This is supported via [[Pipeline]] and [[PipelineModel]].
+ * @note This is experimental and unstable. Do not use this unless you fully
+ *       understand what `Unstable` means.
  */
 @Unstable
-sealed trait MLEvent extends SparkListenerEvent
+sealed trait MLEvent extends SparkListenerEvent {
+  // Do not log ML events in event log. It should be revisited to see
+  // how it works with history server.
+  protected[spark] override def logEvent: Boolean = false
+}
 
 /**
  * Event fired before `Transformer.transform`.
  */
 @Unstable
-case class TransformStart(transformer: Transformer, input: Dataset[_]) extends MLEvent
+case class TransformStart() extends MLEvent {
+  @JsonIgnore var transformer: Transformer = _
+  @JsonIgnore var input: Dataset[_] = _
+}
+
 /**
  * Event fired after `Transformer.transform`.
  */
 @Unstable
-case class TransformEnd(transformer: Transformer, output: Dataset[_]) extends MLEvent
+case class TransformEnd() extends MLEvent {
+  @JsonIgnore var transformer: Transformer = _
+  @JsonIgnore var output: Dataset[_] = _
+}
 
 /**
  * Event fired before `Estimator.fit`.
  */
 @Unstable
-case class FitStart[M <: Model[M]](estimator: Estimator[M], dataset: Dataset[_]) extends MLEvent
+case class FitStart[M <: Model[M]]() extends MLEvent {
+  @JsonIgnore var estimator: Estimator[M] = _
+  @JsonIgnore var dataset: Dataset[_] = _
+}
+
 /**
  * Event fired after `Estimator.fit`.
  */
 @Unstable
-case class FitEnd[M <: Model[M]](estimator: Estimator[M], model: M) extends MLEvent
+case class FitEnd[M <: Model[M]]() extends MLEvent {
+  @JsonIgnore var estimator: Estimator[M] = _
+  @JsonIgnore var model: M = _
+}
 
 /**
  * Event fired before `MLReader.load`.
  */
 @Unstable
-case class LoadInstanceStart[T](reader: MLReader[T], path: String) extends MLEvent
+case class LoadInstanceStart[T](path: String) extends MLEvent {
+  @JsonIgnore var reader: MLReader[T] = _
+}
+
 /**
  * Event fired after `MLReader.load`.
  */
 @Unstable
-case class LoadInstanceEnd[T](reader: MLReader[T], instance: T) extends MLEvent
+case class LoadInstanceEnd[T]() extends MLEvent {
+  @JsonIgnore var reader: MLReader[T] = _
+  @JsonIgnore var instance: T = _
+}
 
 /**
  * Event fired before `MLWriter.save`.
  */
 @Unstable
-case class SaveInstanceStart(writer: MLWriter, path: String) extends MLEvent
+case class SaveInstanceStart(path: String) extends MLEvent {
+  @JsonIgnore var writer: MLWriter = _
+}
+
 /**
  * Event fired after `MLWriter.save`.
  */
 @Unstable
-case class SaveInstanceEnd(writer: MLWriter, path: String) extends MLEvent
+case class SaveInstanceEnd(path: String) extends MLEvent {
+  @JsonIgnore var writer: MLWriter = _
+}
 
 /**
  * A small trait that defines some methods to send [[org.apache.spark.ml.MLEvent]].
@@ -91,11 +124,15 @@ private[ml] trait MLEvents extends Logging {
   def withFitEvent[M <: Model[M]](
       estimator: Estimator[M], dataset: Dataset[_])(func: => M): M = {
-    val startEvent = FitStart(estimator, dataset)
+    val startEvent = FitStart[M]()
+    startEvent.estimator = estimator
+    startEvent.dataset = dataset
     logEvent(startEvent)
     listenerBus.post(startEvent)
     val model: M = func
-    val endEvent = FitEnd(estimator, model)
+    val endEvent = FitEnd[M]()
+    endEvent.estimator = estimator
+    endEvent.model = model
     logEvent(endEvent)
     listenerBus.post(endEvent)
     model
@@ -103,34 +140,42 @@ private[ml] trait MLEvents extends Logging {
 
   def withTransformEvent(
       transformer: Transformer, input: Dataset[_])(func: => DataFrame): DataFrame = {
-    val startEvent = TransformStart(transformer, input)
+    val startEvent = TransformStart()
+    startEvent.transformer = transformer
+    startEvent.input = input
     logEvent(startEvent)
     listenerBus.post(startEvent)
     val output: DataFrame = func
-    val endEvent = TransformEnd(transformer, output)
+    val endEvent = TransformEnd()
+    endEvent.transformer = transformer
+    endEvent.output = output
     logEvent(endEvent)
     listenerBus.post(endEvent)
     output
   }
 
   def withLoadInstanceEvent[T](reader: MLReader[T], path: String)(func: => T): T = {
-    val startEvent = LoadInstanceStart(reader, path)
+    val startEvent = LoadInstanceStart[T](path)
+    startEvent.reader = reader
     logEvent(startEvent)
     listenerBus.post(startEvent)
     val instance: T = func
-    val endEvent = LoadInstanceEnd(reader, instance)
+    val endEvent = LoadInstanceEnd[T]()
+    endEvent.reader = reader
+    endEvent.instance = instance
     logEvent(endEvent)
     listenerBus.post(endEvent)
     instance
   }
 
   def withSaveInstanceEvent(writer: MLWriter, path: String)(func: => Unit): Unit = {
-    listenerBus.post(SaveInstanceEnd(writer, path))
-    val startEvent = SaveInstanceStart(writer, path)
+    val startEvent = SaveInstanceStart(path)
+    startEvent.writer = writer
     logEvent(startEvent)
     listenerBus.post(startEvent)
     func
-    val endEvent = SaveInstanceEnd(writer, path)
+    val endEvent = SaveInstanceEnd(path)
+    endEvent.writer = writer
     logEvent(endEvent)
     listenerBus.post(endEvent)
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
index 0a87328de643e..80ae0c788ac53 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MLWri
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
 import org.apache.spark.sql._
+import org.apache.spark.util.JsonProtocol
 
 
 class MLEventsSuite
@@ -107,20 +108,48 @@ class MLEventsSuite
         .setStages(Array(estimator1, transformer1, estimator2))
       assert(events.isEmpty)
       val pipelineModel = pipeline.fit(dataset1)
-      val expected =
-        FitStart(pipeline, dataset1) ::
-        FitStart(estimator1, dataset1) ::
-        FitEnd(estimator1, model1) ::
-        TransformStart(model1, dataset1) ::
-        TransformEnd(model1, dataset2) ::
-        TransformStart(transformer1, dataset2) ::
-        TransformEnd(transformer1, dataset3) ::
-        FitStart(estimator2, dataset3) ::
-        FitEnd(estimator2, model2) ::
-        FitEnd(pipeline, pipelineModel) :: Nil
+
+      val event0 = FitStart[PipelineModel]()
+      event0.estimator = pipeline
+      event0.dataset = dataset1
+      val event1 = FitStart[MyModel]()
+      event1.estimator = estimator1
+      event1.dataset = dataset1
+      val event2 = FitEnd[MyModel]()
+      event2.estimator = estimator1
+      event2.model = model1
+      val event3 = TransformStart()
+      event3.transformer = model1
+      event3.input = dataset1
+      val event4 = TransformEnd()
+      event4.transformer = model1
+      event4.output = dataset2
+      val event5 = TransformStart()
+      event5.transformer = transformer1
+      event5.input = dataset2
+      val event6 = TransformEnd()
+      event6.transformer = transformer1
+      event6.output = dataset3
+      val event7 = FitStart[MyModel]()
+      event7.estimator = estimator2
+      event7.dataset = dataset3
+      val event8 = FitEnd[MyModel]()
+      event8.estimator = estimator2
+      event8.model = model2
+      val event9 = FitEnd[PipelineModel]()
+      event9.estimator = pipeline
+      event9.model = pipelineModel
+
+      val expected = Seq(
+        event0, event1, event2, event3, event4, event5, event6, event7, event8, event9)
       eventually(timeout(10 seconds), interval(1 second)) {
         assert(events === expected)
       }
+      // Test if they can be ser/de via JSON protocol.
+      assert(events.nonEmpty)
+      events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+        assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+      }
   }
 
   test("pipeline model transform events") {
@@ -144,18 +173,41 @@ class MLEventsSuite
         "pipeline0", Array(transformer1, model, transformer2))
       assert(events.isEmpty)
       val output = newPipelineModel.transform(dataset1)
-      val expected =
-        TransformStart(newPipelineModel, dataset1) ::
-        TransformStart(transformer1, dataset1) ::
-        TransformEnd(transformer1, dataset2) ::
-        TransformStart(model, dataset2) ::
-        TransformEnd(model, dataset3) ::
-        TransformStart(transformer2, dataset3) ::
-        TransformEnd(transformer2, dataset4) ::
-        TransformEnd(newPipelineModel, output) :: Nil
+
+      val event0 = TransformStart()
+      event0.transformer = newPipelineModel
+      event0.input = dataset1
+      val event1 = TransformStart()
+      event1.transformer = transformer1
+      event1.input = dataset1
+      val event2 = TransformEnd()
+      event2.transformer = transformer1
+      event2.output = dataset2
+      val event3 = TransformStart()
+      event3.transformer = model
+      event3.input = dataset2
+      val event4 = TransformEnd()
+      event4.transformer = model
+      event4.output = dataset3
+      val event5 = TransformStart()
+      event5.transformer = transformer2
+      event5.input = dataset3
+      val event6 = TransformEnd()
+      event6.transformer = transformer2
+      event6.output = dataset4
+      val event7 = TransformEnd()
+      event7.transformer = newPipelineModel
+      event7.output = output
+
+      val expected = Seq(event0, event1, event2, event3, event4, event5, event6, event7)
       eventually(timeout(10 seconds), interval(1 second)) {
         assert(events === expected)
       }
+      // Test if they can be ser/de via JSON protocol.
+      assert(events.nonEmpty)
+      events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+        assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+      }
   }
 
   test("pipeline read/write events") {
@@ -182,6 +234,11 @@ class MLEventsSuite
           case e => fail(s"Unexpected event thrown: $e")
         }
       }
+      // Test if they can be ser/de via JSON protocol.
+      assert(events.nonEmpty)
+      events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+        assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+      }
       events.clear()
 
       val pipelineReader = Pipeline.read
@@ -202,6 +259,11 @@ class MLEventsSuite
           case e => fail(s"Unexpected event thrown: $e")
         }
       }
+      // Test if they can be ser/de via JSON protocol.
+      assert(events.nonEmpty)
+      events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+        assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+      }
     }
   }
 
@@ -230,6 +292,11 @@ class MLEventsSuite
           case e => fail(s"Unexpected event thrown: $e")
         }
       }
+      // Test if they can be ser/de via JSON protocol.
+      assert(events.nonEmpty)
+      events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+        assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+      }
       events.clear()
 
       val pipelineModelReader = PipelineModel.read
@@ -250,6 +317,11 @@ class MLEventsSuite
          case e => fail(s"Unexpected event thrown: $e")
         }
       }
+      // Test if they can be ser/de via JSON protocol.
+      assert(events.nonEmpty)
+      events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+        assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+      }
     }
   }
 }
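
A note on why the `@JsonIgnore` pattern above makes these events safe (not part of the patch): if I read `JsonProtocol` correctly, it falls back to a Jackson `ObjectMapper` for listener events it does not know about, so every property Jackson can see on an event must be JSON-serializable. Keeping the serializable payload (such as `path`) as a constructor field and marking references to live objects (`Transformer`, `Dataset`, `MLWriter`, ...) with `@JsonIgnore` keeps the round trip safe. Below is a minimal, self-contained sketch of that pattern using plain Jackson; `SaveStartExample` and `JsonSafetyDemo` are hypothetical names for illustration, and it assumes `jackson-module-scala` is on the classpath:

```scala
import com.fasterxml.jackson.annotation.JsonIgnore
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule

// Mirrors the shape of SaveInstanceStart: `path` is the JSON payload,
// while `writer` stands in for a non-serializable object (e.g. an MLWriter).
case class SaveStartExample(path: String) {
  @JsonIgnore var writer: AnyRef = _
}

object JsonSafetyDemo {
  def main(args: Array[String]): Unit = {
    val mapper = new ObjectMapper().registerModule(DefaultScalaModule)

    val event = SaveStartExample("/tmp/my-model")
    // Without @JsonIgnore, Jackson would try to serialize this object too
    // and could fail or emit junk for an arbitrary object graph.
    event.writer = new Object()

    val json = mapper.writeValueAsString(event)
    println(json)  // expected: {"path":"/tmp/my-model"} - writer is omitted

    // Deserialization only needs the constructor fields, so the round trip
    // succeeds even though `writer` was never written out.
    val back = mapper.readValue(json, classOf[SaveStartExample])
    assert(back.path == event.path)
  }
}
```

This is also why the tests above can assert `JsonProtocol.sparkEventFromJson(...).isInstanceOf[MLEvent]`: the ignored fields simply come back as `null`, while the event type and its serializable payload survive the round trip.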