Skip to content

Commit

Permalink
Support subnetwork hooks requesting early stopping.
Browse files Browse the repository at this point in the history
We do so by wrapping the subnetwork SessionRunHooks and passing a temporary SessionRunContext to the hooks' methods, so as to intercept early-stopping requests and handle them inside AdaNet instead of letting the MonitoredTrainingSession receive the request directly.

PiperOrigin-RevId: 269669834
  • Loading branch information
cweill committed Sep 17, 2019
1 parent ba54c19 commit 92b955a
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 3 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

# Current version (0.8.0.dev)
* Under development.
* Support subnetwork hooks requesting early stopping.
* Adding AdaNet replay. The ability to rerun training without having to determine the best candidate for the iteration. A list of best indices from the previous run is provided and honored by AdaNet.
* TODO: Add official Keras Model support, including Keras layers, Sequential, and Model subclasses for defining subnetworks.
* Introduced `adanet.ensemble.MeanEnsembler` with a basic implementation for taking the mean of logits of subnetworks. This also supports including the mean of last_layer (helpful if subnetworks have same configurations) in the `predictions` and `export_outputs` of the EstimatorSpec.
Expand Down
4 changes: 3 additions & 1 deletion adanet/core/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,9 @@ class Estimator(tf.estimator.Estimator):
TensorBoard for each subnetwork. Disable to reduce memory and disk usage
per run.
global_step_combiner_fn: Function for combining each subnetwork's
iteration step into the global step.
iteration step into the global step. By default it is the average of all
subnetwork iteration steps, which may affect the global_steps/sec as
subnetworks stop early and no longer increase their iteration step.
max_iterations: Integer maximum number of AdaNet iterations (a.k.a. rounds)
of generating new subnetworks and ensembles, training them, and evaluating
them against the current best ensemble. When :code:`None`, AdaNet will
Expand Down
29 changes: 29 additions & 0 deletions adanet/core/estimator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,13 @@ def evaluate(self, sess, ensemble_metrics):
return losses


class _EarlyStoppingHook(tf_compat.SessionRunHook):
  """A `SessionRunHook` whose only action is to request an early stop."""

  def after_run(self, run_context, run_values):
    # Ask the surrounding training loop to stop after the very first step.
    run_context.request_stop()


class EstimatorTest(tu.AdanetTestCase):

@parameterized.named_parameters(
Expand Down Expand Up @@ -985,6 +992,28 @@ class EstimatorTest(tu.AdanetTestCase):
2,
"want_global_step":
300,
},
{
"testcase_name":
"early_stopping_subnetwork",
"subnetwork_generator":
SimpleGenerator([
_DNNBuilder("dnn"),
_DNNBuilder("dnn2", subnetwork_hooks=[_EarlyStoppingHook()])
]),
"max_iteration_steps":
100,
"max_steps":
200,
"want_loss":
0.2958503,
# Since one subnetwork stops after 1 step and global step is the
# mean of iteration steps, global step will be incremented at half
# the rate.
"want_iteration":
3,
"want_global_step":
200,
})
def test_lifecycle(self,
subnetwork_generator,
Expand Down
14 changes: 12 additions & 2 deletions adanet/core/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,23 @@ def after_create_session(self, session, coord):

def before_run(self, run_context):
  """See `SessionRunHook`.

  Delegates to the wrapped hook's `before_run`, but hands it a temporary
  `SessionRunContext` so that a stop requested by the hook is intercepted
  here and routed to the train manager (stopping only this subnetwork's
  training) rather than reaching the MonitoredTrainingSession directly.

  Args:
    run_context: The `SessionRunContext` for the upcoming run call.

  Returns:
    The `SessionRunArgs` produced by the wrapped hook, or `None` when this
    subnetwork should no longer train.
  """
  if self._train_manager.should_train(self._spec):
    # Use a tmp run context to intercept if the hook requests stop.
    tmp_run_context = tf_compat.v1.train.SessionRunContext(
        run_context.original_args, run_context.session)
    with self._session_run_context():
      # Capture the hook's run args first: returning directly from inside
      # the `with` block would make the stop-request check below dead code.
      run_args = self._hook.before_run(tmp_run_context)
    if tmp_run_context.stop_requested:
      self._train_manager.request_stop(self._spec, "Stop requested.")
    return run_args

def after_run(self, run_context, run_values):
  """See `SessionRunHook`.

  Delegates to the wrapped hook's `after_run` via a temporary
  `SessionRunContext`, so an early-stopping request made by the hook is
  handled by the train manager instead of stopping the whole session.
  Note: the hook must be called exactly once, and only with the temporary
  context — passing the real `run_context` would bypass interception.

  Args:
    run_context: The `SessionRunContext` for the completed run call.
    run_values: The `SessionRunValues` returned from that run call.
  """
  if self._train_manager.should_train(self._spec):
    # Use a tmp run context to intercept if the hook requests stop.
    tmp_run_context = tf_compat.v1.train.SessionRunContext(
        run_context.original_args, run_context.session)
    with self._session_run_context():
      self._hook.after_run(tmp_run_context, run_values)
    if tmp_run_context.stop_requested:
      self._train_manager.request_stop(self._spec, "Stop requested.")

def end(self, session):
with self._session_run_context():
Expand Down

0 comments on commit 92b955a

Please sign in to comment.