diff --git a/nobrainer/processing/generation.py b/nobrainer/processing/generation.py index 2d4b93f8..da6a0edb 100644 --- a/nobrainer/processing/generation.py +++ b/nobrainer/processing/generation.py @@ -157,6 +157,10 @@ def _compile(): if info.get("normalizer") or normalizer: dataset = dataset.normalize(normalizer) + n_epochs = info.get("epochs") or epochs + dataset = dataset.repeat(n_epochs).batch(batch_size) + steps_per_epoch = dataset.get_steps_per_epoch() + with self.strategy.scope(): # grow the networks by one (2^x) resolution if resolution > self.current_resolution_: @@ -164,8 +168,6 @@ def _compile(): self.model_.discriminator.add_resolution() _compile() - steps_per_epoch = dataset.get_steps_per_epoch() - # save_best_only is set to False as it is an adversarial loss model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( str(model_dir), @@ -180,7 +182,7 @@ def _compile(): print("Transition phase") self.model_.fit( - dataset, + dataset.dataset, phase="transition", resolution=resolution, steps_per_epoch=steps_per_epoch, # necessary for repeat dataset @@ -189,7 +191,7 @@ def _compile(): print("Resolution phase") self.model_.fit( - dataset, + dataset.dataset, phase="resolution", resolution=resolution, steps_per_epoch=steps_per_epoch,