[SPARK-26818][ML] Make MLEvents JSON ser/de safe
## What changes were proposed in this pull request?

Currently, it does not look like this causes any actual problem in practice (if I didn't misread the code).

I see one place where JSON-formatted events are used:

https://github.com/apache/spark/blob/ec506bd30c2ca324c12c9ec811764081c2eb8c42/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala#L148

It's okay because such an exception is simply caught and logged when it is ignorable:

https://github.com/apache/spark/blob/9690eba16efe6d25261934d8b73a221972b684f3/core/src/main/scala/org/apache/spark/util/ListenerBus.scala#L111

I think it's best to stay safe - I don't want this unstable, experimental feature to break anything in any case. This patch also disables `logEvent` in `SparkListenerEvent` for the same reason.
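For context, here is a minimal sketch (not part of this patch; `spark` is an assumed `SparkSession` in scope) of how ML events can still be observed from a live listener even though `logEvent` keeps them out of the persisted event log:

```scala
import org.apache.spark.ml.MLEvent
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}

// Sketch only: logEvent = false skips the event log file, but events are
// still posted to the listener bus and arrive via onOtherEvent.
spark.sparkContext.addSparkListener(new SparkListener {
  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case e: MLEvent => println(s"ML event: $e")
    case _ => // not an ML event; ignore
  }
})
```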

This is also done to match the SQL execution events side:

https://github.com/apache/spark/blob/ca545f79410a464ef24e3986fac225f53bb2ef02/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala#L41-L57

to make ML events JSON ser/de safe.
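To illustrate what "JSON ser/de safe" means, here is a minimal sketch of round-tripping an ML event through `JsonProtocol`, mirroring the added unit tests (a sketch, not the definitive usage):

```scala
import org.apache.spark.ml.TransformStart
import org.apache.spark.util.JsonProtocol

// The Transformer/Dataset fields are now @JsonIgnore'd vars, so JSON
// serialization no longer touches non-serializable objects.
val event = TransformStart()
val json = JsonProtocol.sparkEventToJson(event)
assert(JsonProtocol.sparkEventFromJson(json).isInstanceOf[TransformStart])
```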

## How was this patch tested?

Manually tested, and unit tests were added.

Closes #23728 from HyukjinKwon/SPARK-26818.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon committed Feb 3, 2019
1 parent 96c6c29 commit dfb8809
Showing 2 changed files with 155 additions and 38 deletions.
81 changes: 63 additions & 18 deletions mllib/src/main/scala/org/apache/spark/ml/events.scala
@@ -17,6 +17,8 @@

package org.apache.spark.ml

+import com.fasterxml.jackson.annotation.JsonIgnore
+
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Unstable
import org.apache.spark.internal.Logging
@@ -29,53 +29,84 @@ import org.apache.spark.sql.{DataFrame, Dataset}
* after each operation (the event should document this).
*
* @note This is supported via [[Pipeline]] and [[PipelineModel]].
+* @note This is experimental and unstable. Do not use this unless you fully
+* understand what `Unstable` means.
*/
@Unstable
-sealed trait MLEvent extends SparkListenerEvent
+sealed trait MLEvent extends SparkListenerEvent {
+// Do not log ML events in event log. It should be revisited to see
+// how it works with history server.
+protected[spark] override def logEvent: Boolean = false
+}

/**
* Event fired before `Transformer.transform`.
*/
@Unstable
-case class TransformStart(transformer: Transformer, input: Dataset[_]) extends MLEvent
+case class TransformStart() extends MLEvent {
+@JsonIgnore var transformer: Transformer = _
+@JsonIgnore var input: Dataset[_] = _
+}

/**
* Event fired after `Transformer.transform`.
*/
@Unstable
-case class TransformEnd(transformer: Transformer, output: Dataset[_]) extends MLEvent
+case class TransformEnd() extends MLEvent {
+@JsonIgnore var transformer: Transformer = _
+@JsonIgnore var output: Dataset[_] = _
+}

/**
* Event fired before `Estimator.fit`.
*/
@Unstable
-case class FitStart[M <: Model[M]](estimator: Estimator[M], dataset: Dataset[_]) extends MLEvent
+case class FitStart[M <: Model[M]]() extends MLEvent {
+@JsonIgnore var estimator: Estimator[M] = _
+@JsonIgnore var dataset: Dataset[_] = _
+}

/**
* Event fired after `Estimator.fit`.
*/
@Unstable
-case class FitEnd[M <: Model[M]](estimator: Estimator[M], model: M) extends MLEvent
+case class FitEnd[M <: Model[M]]() extends MLEvent {
+@JsonIgnore var estimator: Estimator[M] = _
+@JsonIgnore var model: M = _
+}

/**
* Event fired before `MLReader.load`.
*/
@Unstable
-case class LoadInstanceStart[T](reader: MLReader[T], path: String) extends MLEvent
+case class LoadInstanceStart[T](path: String) extends MLEvent {
+@JsonIgnore var reader: MLReader[T] = _
+}

/**
* Event fired after `MLReader.load`.
*/
@Unstable
-case class LoadInstanceEnd[T](reader: MLReader[T], instance: T) extends MLEvent
+case class LoadInstanceEnd[T]() extends MLEvent {
+@JsonIgnore var reader: MLReader[T] = _
+@JsonIgnore var instance: T = _
+}

/**
* Event fired before `MLWriter.save`.
*/
@Unstable
-case class SaveInstanceStart(writer: MLWriter, path: String) extends MLEvent
+case class SaveInstanceStart(path: String) extends MLEvent {
+@JsonIgnore var writer: MLWriter = _
+}

/**
* Event fired after `MLWriter.save`.
*/
@Unstable
-case class SaveInstanceEnd(writer: MLWriter, path: String) extends MLEvent
+case class SaveInstanceEnd(path: String) extends MLEvent {
+@JsonIgnore var writer: MLWriter = _
+}

/**
* A small trait that defines some methods to send [[org.apache.spark.ml.MLEvent]].
@@ -91,46 +91,58 @@ private[ml] trait MLEvents extends Logging {

def withFitEvent[M <: Model[M]](
estimator: Estimator[M], dataset: Dataset[_])(func: => M): M = {
-val startEvent = FitStart(estimator, dataset)
+val startEvent = FitStart[M]()
+startEvent.estimator = estimator
+startEvent.dataset = dataset
logEvent(startEvent)
listenerBus.post(startEvent)
val model: M = func
-val endEvent = FitEnd(estimator, model)
+val endEvent = FitEnd[M]()
+endEvent.estimator = estimator
+endEvent.model = model
logEvent(endEvent)
listenerBus.post(endEvent)
model
}

def withTransformEvent(
transformer: Transformer, input: Dataset[_])(func: => DataFrame): DataFrame = {
-val startEvent = TransformStart(transformer, input)
+val startEvent = TransformStart()
+startEvent.transformer = transformer
+startEvent.input = input
logEvent(startEvent)
listenerBus.post(startEvent)
val output: DataFrame = func
-val endEvent = TransformEnd(transformer, output)
+val endEvent = TransformEnd()
+endEvent.transformer = transformer
+endEvent.output = output
logEvent(endEvent)
listenerBus.post(endEvent)
output
}

def withLoadInstanceEvent[T](reader: MLReader[T], path: String)(func: => T): T = {
-val startEvent = LoadInstanceStart(reader, path)
+val startEvent = LoadInstanceStart[T](path)
+startEvent.reader = reader
logEvent(startEvent)
listenerBus.post(startEvent)
val instance: T = func
-val endEvent = LoadInstanceEnd(reader, instance)
+val endEvent = LoadInstanceEnd[T]()
+endEvent.reader = reader
+endEvent.instance = instance
logEvent(endEvent)
listenerBus.post(endEvent)
instance
}

def withSaveInstanceEvent(writer: MLWriter, path: String)(func: => Unit): Unit = {
-listenerBus.post(SaveInstanceEnd(writer, path))
-val startEvent = SaveInstanceStart(writer, path)
+val startEvent = SaveInstanceStart(path)
+startEvent.writer = writer
logEvent(startEvent)
listenerBus.post(startEvent)
func
-val endEvent = SaveInstanceEnd(writer, path)
+val endEvent = SaveInstanceEnd(path)
+endEvent.writer = writer
logEvent(endEvent)
listenerBus.post(endEvent)
}
112 changes: 92 additions & 20 deletions mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MLWri
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql._
+import org.apache.spark.util.JsonProtocol


class MLEventsSuite
@@ -107,20 +107,48 @@
.setStages(Array(estimator1, transformer1, estimator2))
assert(events.isEmpty)
val pipelineModel = pipeline.fit(dataset1)
-val expected =
-FitStart(pipeline, dataset1) ::
-FitStart(estimator1, dataset1) ::
-FitEnd(estimator1, model1) ::
-TransformStart(model1, dataset1) ::
-TransformEnd(model1, dataset2) ::
-TransformStart(transformer1, dataset2) ::
-TransformEnd(transformer1, dataset3) ::
-FitStart(estimator2, dataset3) ::
-FitEnd(estimator2, model2) ::
-FitEnd(pipeline, pipelineModel) :: Nil
+
+val event0 = FitStart[PipelineModel]()
+event0.estimator = pipeline
+event0.dataset = dataset1
+val event1 = FitStart[MyModel]()
+event1.estimator = estimator1
+event1.dataset = dataset1
+val event2 = FitEnd[MyModel]()
+event2.estimator = estimator1
+event2.model = model1
+val event3 = TransformStart()
+event3.transformer = model1
+event3.input = dataset1
+val event4 = TransformEnd()
+event4.transformer = model1
+event4.output = dataset2
+val event5 = TransformStart()
+event5.transformer = transformer1
+event5.input = dataset2
+val event6 = TransformEnd()
+event6.transformer = transformer1
+event6.output = dataset3
+val event7 = FitStart[MyModel]()
+event7.estimator = estimator2
+event7.dataset = dataset3
+val event8 = FitEnd[MyModel]()
+event8.estimator = estimator2
+event8.model = model2
+val event9 = FitEnd[PipelineModel]()
+event9.estimator = pipeline
+event9.model = pipelineModel
+
+val expected = Seq(
+event0, event1, event2, event3, event4, event5, event6, event7, event8, event9)
eventually(timeout(10 seconds), interval(1 second)) {
assert(events === expected)
}
+// Test if they can be ser/de via JSON protocol.
+assert(events.nonEmpty)
+events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+}
}

test("pipeline model transform events") {
@@ -144,18 +144,41 @@
"pipeline0", Array(transformer1, model, transformer2))
assert(events.isEmpty)
val output = newPipelineModel.transform(dataset1)
-val expected =
-TransformStart(newPipelineModel, dataset1) ::
-TransformStart(transformer1, dataset1) ::
-TransformEnd(transformer1, dataset2) ::
-TransformStart(model, dataset2) ::
-TransformEnd(model, dataset3) ::
-TransformStart(transformer2, dataset3) ::
-TransformEnd(transformer2, dataset4) ::
-TransformEnd(newPipelineModel, output) :: Nil
+
+val event0 = TransformStart()
+event0.transformer = newPipelineModel
+event0.input = dataset1
+val event1 = TransformStart()
+event1.transformer = transformer1
+event1.input = dataset1
+val event2 = TransformEnd()
+event2.transformer = transformer1
+event2.output = dataset2
+val event3 = TransformStart()
+event3.transformer = model
+event3.input = dataset2
+val event4 = TransformEnd()
+event4.transformer = model
+event4.output = dataset3
+val event5 = TransformStart()
+event5.transformer = transformer2
+event5.input = dataset3
+val event6 = TransformEnd()
+event6.transformer = transformer2
+event6.output = dataset4
+val event7 = TransformEnd()
+event7.transformer = newPipelineModel
+event7.output = output
+
+val expected = Seq(event0, event1, event2, event3, event4, event5, event6, event7)
eventually(timeout(10 seconds), interval(1 second)) {
assert(events === expected)
}
+// Test if they can be ser/de via JSON protocol.
+assert(events.nonEmpty)
+events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+}
}

test("pipeline read/write events") {
@@ -182,6 +182,11 @@ class MLEventsSuite
case e => fail(s"Unexpected event thrown: $e")
}
}
+// Test if they can be ser/de via JSON protocol.
+assert(events.nonEmpty)
+events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+}

events.clear()
val pipelineReader = Pipeline.read
@@ -202,6 +202,11 @@
case e => fail(s"Unexpected event thrown: $e")
}
}
+// Test if they can be ser/de via JSON protocol.
+assert(events.nonEmpty)
+events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+}
}
}

@@ -230,6 +230,11 @@ class MLEventsSuite
case e => fail(s"Unexpected event thrown: $e")
}
}
+// Test if they can be ser/de via JSON protocol.
+assert(events.nonEmpty)
+events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+}

events.clear()
val pipelineModelReader = PipelineModel.read
@@ -250,6 +250,11 @@
case e => fail(s"Unexpected event thrown: $e")
}
}
+// Test if they can be ser/de via JSON protocol.
+assert(events.nonEmpty)
+events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+}
}
}
}
