diff --git a/acme/tf/savers.py b/acme/tf/savers.py index f3b623e71d..1b33478afd 100644 --- a/acme/tf/savers.py +++ b/acme/tf/savers.py @@ -150,23 +150,38 @@ def save(self, force: bool = False) -> bool: time.time() - self._last_saved < 60 * self._time_delta_minutes): return False + checkpoint_manager: tf.train.CheckpointManager = self.checkpoint_manager # Save any checkpoints. - logging.info('Saving checkpoint: %s', self._checkpoint_manager.directory) - self._checkpoint_manager.save() + logging.info('Saving checkpoint: %s', checkpoint_manager.directory) + checkpoint_manager.save() self._last_saved = time.time() return True def restore(self): + """Restore from most recent checkpoint.""" + # Restore from the most recent checkpoint (if it exists). - checkpoint_to_restore = self._checkpoint_manager.latest_checkpoint + checkpoint_to_restore = self.checkpoint_manager.latest_checkpoint logging.info('Attempting to restore checkpoint: %s', checkpoint_to_restore) self._checkpoint.restore(checkpoint_to_restore) @property def directory(self): - return self._checkpoint_manager.directory + return self.checkpoint_manager.directory + + @property + def checkpoint_manager(self) -> tf.train.CheckpointManager: + if not self._enable_checkpointing: + raise ValueError( + 'Check-point not enabled. No checkpoint manager available.' + ) + + # At this point, _enable_checkpointing is true, so _checkpoint_manager + # should not be None. + assert self._checkpoint_manager is not None + return self._checkpoint_manager class CheckpointingRunner(core.Worker): @@ -332,6 +347,8 @@ def __init__(self): @tf.function def __call__(self, *args, **kwargs): + if self._module is None: + raise ValueError('_module not set') return self._module(*args, **kwargs) @property