diff --git a/examples/roberta/commonsense_qa/commonsense_qa_task.py b/examples/roberta/commonsense_qa/commonsense_qa_task.py index 39a16d2948..274e8d39aa 100644 --- a/examples/roberta/commonsense_qa/commonsense_qa_task.py +++ b/examples/roberta/commonsense_qa/commonsense_qa_task.py @@ -66,7 +66,7 @@ def setup_task(cls, args, **kwargs): return cls(args, vocab) - def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs): """Load a given dataset split. Args: diff --git a/examples/roberta/wsc/wsc_task.py b/examples/roberta/wsc/wsc_task.py index 757dfb82fa..54d245c583 100644 --- a/examples/roberta/wsc/wsc_task.py +++ b/examples/roberta/wsc/wsc_task.py @@ -101,7 +101,7 @@ def binarize_with_mask(self, txt, prefix, suffix, leading_space, trailing_space) mask[mask_start:mask_start + mask_size] = 1 return toks, mask - def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs): """Load a given dataset split. Args: @@ -281,7 +281,7 @@ def setup_task(cls, args, **kwargs): return cls(args, vocab) - def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs): """Load a given dataset split. Args: diff --git a/fairseq/benchmark/dummy_lm.py b/fairseq/benchmark/dummy_lm.py index 017e8c3d79..81831c4529 100644 --- a/fairseq/benchmark/dummy_lm.py +++ b/fairseq/benchmark/dummy_lm.py @@ -42,7 +42,7 @@ def setup_task(cls, args, **kwargs): return cls(args, dictionary) - def load_dataset(self, split, epoch=0, combine=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) diff --git a/fairseq/benchmark/dummy_masked_lm.py b/fairseq/benchmark/dummy_masked_lm.py index 721fcddbc0..56e5ea5fbe 100644 --- a/fairseq/benchmark/dummy_masked_lm.py +++ b/fairseq/benchmark/dummy_masked_lm.py @@ -53,7 +53,7 @@ def setup_task(cls, args, **kwargs): return cls(args, dictionary) - def load_dataset(self, split, epoch=0, combine=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 8a297d4509..77dbb29b5e 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -109,6 +109,7 @@ def is_better(a, b): if os.path.lexists(old_chk): os.remove(old_chk) + def load_checkpoint(args, trainer, **passthrough_args): """ Load a checkpoint and restore the training iterator. 
@@ -150,7 +151,7 @@ def load_checkpoint(args, trainer, **passthrough_args): epoch_itr.load_state_dict(itr_state) else: epoch_itr = trainer.get_train_iterator( - epoch=0, load_dataset=True, **passthrough_args + epoch=1, load_dataset=True, **passthrough_args ) trainer.lr_step(epoch_itr.epoch) @@ -349,6 +350,11 @@ def _upgrade_state_dict(state): state["args"].dataset_impl = "raw" elif getattr(state["args"], "lazy_load", False): state["args"].dataset_impl = "lazy" + # epochs start at 1 + state["extra_state"]["train_iterator"]["epoch"] = max( + state["extra_state"]["train_iterator"].get("epoch", 1), + 1, + ) # set any missing default values in the task, model or other registries registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task]) diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index 7aedec0d05..f844564dbc 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -38,7 +38,6 @@ from .resampling_dataset import ResamplingDataset from .roll_dataset import RollDataset from .round_robin_zip_datasets import RoundRobinZipDatasets -from .sharded_dataset import ShardedDataset from .sort_dataset import SortDataset from .strip_token_dataset import StripTokenDataset from .subsample_dataset import SubsampleDataset @@ -96,7 +95,6 @@ 'ResamplingDataset', 'RightPadDataset', 'RoundRobinZipDatasets', - 'ShardedDataset', 'ShardedIterator', 'SortDataset', 'StripTokenDataset', diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 2a1c30a75f..4a9fcaeeae 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -100,17 +100,18 @@ def load_state_dict(self, state_dict): class StreamingEpochBatchIterator(EpochBatchIterating): def __init__( - self, dataset, epoch=0, num_shards=1, shard_id=0, + self, dataset, epoch=1, num_shards=1, shard_id=0, ): assert isinstance(dataset, torch.utils.data.IterableDataset) self.dataset = dataset - self.epoch = epoch + self.epoch = max(epoch, 1) # we use 1-based indexing for epochs self._current_epoch_iterator = None self.num_shards = num_shards self.shard_id = shard_id def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): - self.epoch += 1 + if self._current_epoch_iterator is not None and self.end_of_epoch(): + self.epoch += 1 self.dataset.set_epoch(self.epoch) self._current_epoch_iterator = CountingIterator( iterable=ShardedIterator( @@ -165,12 +166,12 @@ class EpochBatchIterator(EpochBatchIterating): loading. 0 means the data will be loaded in the main process (default: 0). epoch (int, optional): the epoch to start the iterator from - (default: 0). + (default: 1).
""" def __init__( self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0, - num_workers=0, epoch=0, + num_workers=0, epoch=1, ): assert isinstance(dataset, torch.utils.data.Dataset) self.dataset = dataset @@ -181,7 +182,7 @@ def __init__( self.shard_id = shard_id self.num_workers = num_workers - self.epoch = epoch + self.epoch = max(epoch, 1) # we use 1-based indexing for epochs self.shuffle = True self._cur_epoch_itr = None self._next_epoch_itr = None @@ -204,7 +205,8 @@ def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): self._cur_epoch_itr = self._next_epoch_itr self._next_epoch_itr = None else: - self.epoch += 1 + if self._cur_epoch_itr is not None and self.end_of_epoch(): + self.epoch += 1 self._cur_epoch_itr = self._get_iterator_for_epoch( self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus, ) @@ -244,6 +246,9 @@ def load_state_dict(self, state_dict): shuffle=state_dict.get('shuffle', True), offset=itr_pos, ) + if self._next_epoch_itr is None: + # we finished the epoch, increment epoch counter + self.epoch += 1 def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False, offset=0): diff --git a/fairseq/data/resampling_dataset.py b/fairseq/data/resampling_dataset.py index 8a2c48f83f..2967916163 100644 --- a/fairseq/data/resampling_dataset.py +++ b/fairseq/data/resampling_dataset.py @@ -31,7 +31,7 @@ class ResamplingDataset(BaseWrapperDataset): batch_by_size (bool): whether or not to batch by sequence length (default: True). seed (int): RNG seed to use (default: 0). - epoch (int): starting epoch number (default: 0). + epoch (int): starting epoch number (default: 1). """ def __init__( @@ -42,7 +42,7 @@ def __init__( size_ratio=1.0, batch_by_size=True, seed=0, - epoch=0, + epoch=1, ): super().__init__(dataset) diff --git a/fairseq/data/sharded_dataset.py b/fairseq/data/sharded_dataset.py deleted file mode 100644 index f7ef4e36cd..0000000000 --- a/fairseq/data/sharded_dataset.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import itertools -import os -import random - -from . import BaseWrapperDataset -from fairseq.data import data_utils - - -class ShardedDataset(BaseWrapperDataset): - """A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS. - - Loads a dataset which has been sharded into multiple files. 
each shard is only loaded for each specific epoch - - """ - - def __init__( - self, - dictionary, - dataset_impl: str, - path: str, - split: str, - epoch: int, - name: str = None, - combine: bool = False, - seed: int = 0, - ): - self._name = name if name is not None else os.path.basename(path) - num_shards = 0 - for i in itertools.count(): - if not os.path.exists(os.path.join(path, "shard" + str(i))): - break - num_shards += 1 - - if num_shards > 0 and split == "train": - random.seed(seed ^ epoch) - shard = random.randint(0, num_shards - 1) - split_path = os.path.join(path, "shard" + str(shard), split) - else: - split_path = os.path.join(path, split) - if os.path.isdir(split_path): - split_path = os.path.join(split_path, split) - - dataset = data_utils.load_indexed_dataset( - split_path, dictionary, dataset_impl, combine=combine - ) - if dataset is None: - raise FileNotFoundError( - "Dataset not found: {} ({})".format(split, split_path) - ) - - super().__init__(dataset) - - @property - def name(self): - return self._name diff --git a/fairseq/tasks/cross_lingual_lm.py b/fairseq/tasks/cross_lingual_lm.py index 0644a08355..3589492f11 100644 --- a/fairseq/tasks/cross_lingual_lm.py +++ b/fairseq/tasks/cross_lingual_lm.py @@ -102,7 +102,7 @@ def _load_single_lang_dataset(self, split, epoch): paths = utils.split_paths(self.args.data) assert len(paths) > 0 - data_path = paths[epoch % len(paths)] + data_path = paths[(epoch - 1) % len(paths)] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') @@ -136,8 +136,9 @@ def _load_single_lang_dataset(self, split, epoch): return dataset, sizes - def load_dataset(self, split, epoch=0, combine=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. + Args: split (str): name of the split (e.g., train, valid, test) """ @@ -165,5 +166,5 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): self.datasets[split] = MultiCorpusSampledDataset(dataset_map) logger.info('{} {} {} examples'.format( - utils.split_paths(self.args.data)[epoch], split, len(self.datasets[split])) + utils.split_paths(self.args.data)[epoch - 1], split, len(self.datasets[split])) ) diff --git a/fairseq/tasks/denoising.py b/fairseq/tasks/denoising.py index a65cc0a626..28beb517f2 100644 --- a/fairseq/tasks/denoising.py +++ b/fairseq/tasks/denoising.py @@ -104,16 +104,15 @@ def setup_task(cls, args, **kwargs): args.shuffle_instance = False return cls(args, dictionary) - def load_dataset(self, split, epoch=0, combine=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ - paths = utils.split_paths(self.args.data) assert len(paths) > 0 - data_path = paths[epoch % len(paths)] + data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) dataset = data_utils.load_indexed_dataset( diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 2d231531b6..48bc8cdff2 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -116,7 +116,7 @@ def get_batch_iterator( num_shards=1, shard_id=0, num_workers=0, - epoch=0, + epoch=1, ): """ Get an iterator that yields batches of data from the given dataset. @@ -143,7 +143,7 @@ def get_batch_iterator( loading. 0 means the data will be loaded in the main process (default: 0). epoch (int, optional): the epoch to start the iterator from - (default: 0). + (default: 1). 
Returns: ~fairseq.iterators.EpochBatchIterator: a batched iterator over the given dataset split diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 07540d3e32..14fe1959f1 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -148,7 +148,7 @@ def build_model(self, args): return model - def load_dataset(self, split, epoch=0, combine=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: @@ -157,7 +157,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): paths = utils.split_paths(self.args.data) assert len(paths) > 0 - data_path = paths[epoch % len(paths)] + data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) dataset = data_utils.load_indexed_dataset( diff --git a/fairseq/tasks/legacy_masked_lm.py b/fairseq/tasks/legacy_masked_lm.py index 21c57b4b6f..40e2724953 100644 --- a/fairseq/tasks/legacy_masked_lm.py +++ b/fairseq/tasks/legacy_masked_lm.py @@ -78,8 +78,9 @@ def setup_task(cls, args, **kwargs): return cls(args, dictionary) - def load_dataset(self, split, epoch=0, combine=False): + def load_dataset(self, split, epoch=1, combine=False): """Load a given dataset split. + Args: split (str): name of the split (e.g., train, valid, test) """ @@ -87,7 +88,7 @@ def load_dataset(self, split, epoch=0, combine=False): paths = utils.split_paths(self.args.data) assert len(paths) > 0 - data_path = paths[epoch % len(paths)] + data_path = paths[(epoch - 1) % len(paths)] logger.info("data_path", data_path) for k in itertools.count(): diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py index fc37f78ecf..8089abf7d8 100644 --- a/fairseq/tasks/masked_lm.py +++ b/fairseq/tasks/masked_lm.py @@ -75,7 +75,7 @@ def setup_task(cls, args, **kwargs): logger.info('dictionary: {} types'.format(len(dictionary))) return cls(args, dictionary) - def load_dataset(self, split, epoch=0, combine=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: @@ -83,7 +83,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 - data_path = paths[epoch % len(paths)] + data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) dataset = data_utils.load_indexed_dataset( diff --git a/fairseq/tasks/multilingual_denoising.py b/fairseq/tasks/multilingual_denoising.py index 04c5c9a9e3..84cc340472 100644 --- a/fairseq/tasks/multilingual_denoising.py +++ b/fairseq/tasks/multilingual_denoising.py @@ -87,15 +87,15 @@ def _get_sample_prob(self, dataset_lens): smoothed_prob = smoothed_prob / smoothed_prob.sum() return smoothed_prob - def load_dataset(self, split, epoch=0, combine=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. 
+ Args: split (str): name of the split (e.g., train, valid, test) """ - paths = self.args.data.split(':') assert len(paths) > 0 - data_path = paths[epoch % len(paths)] + data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) if self.langs is None: diff --git a/fairseq/tasks/multilingual_masked_lm.py b/fairseq/tasks/multilingual_masked_lm.py index 816724c05c..248724bd56 100644 --- a/fairseq/tasks/multilingual_masked_lm.py +++ b/fairseq/tasks/multilingual_masked_lm.py @@ -116,7 +116,7 @@ def _get_sample_prob(self, dataset_lens): smoothed_prob = smoothed_prob / smoothed_prob.sum() return smoothed_prob - def load_dataset(self, split, epoch=0, combine=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: @@ -124,7 +124,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 - data_path = paths[epoch % len(paths)] + data_path = paths[(epoch - 1) % len(paths)] languages = sorted( name for name in os.listdir(data_path) @@ -295,7 +295,7 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): def get_batch_iterator( self, dataset, max_tokens=None, max_sentences=None, max_positions=None, ignore_invalid_inputs=False, required_batch_size_multiple=1, - seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0, + seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1, ): # Recreate epoch iterator every epoch cause the underlying # datasets are dynamic due to sampling. diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index 48bb4df2e8..9a1315e1db 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -187,12 +187,11 @@ def alter_dataset_langtok(self, lang_pair_dataset, new_tgt_bos=new_tgt_bos, ) - def load_dataset(self, split, epoch=0, **kwargs): + def load_dataset(self, split, epoch=1, **kwargs): """Load a dataset split.""" - paths = utils.split_paths(self.args.data) assert len(paths) > 0 - data_path = paths[epoch % len(paths)] + data_path = paths[(epoch - 1) % len(paths)] def language_pair_dataset(lang_pair): src, tgt = lang_pair.split('-') diff --git a/fairseq/tasks/semisupervised_translation.py b/fairseq/tasks/semisupervised_translation.py index b67e5fca65..bf770bfe15 100644 --- a/fairseq/tasks/semisupervised_translation.py +++ b/fairseq/tasks/semisupervised_translation.py @@ -136,12 +136,11 @@ def setup_task(cls, args, **kwargs): dicts, training = MultilingualTranslationTask.prepare(args, **kwargs) return cls(args, dicts, training) - def load_dataset(self, split, epoch=0, **kwargs): + def load_dataset(self, split, epoch=1, **kwargs): """Load a dataset split.""" - paths = utils.split_paths(self.args.data) assert len(paths) > 0 - data_path = paths[epoch % len(paths)] + data_path = paths[(epoch - 1) % len(paths)] def split_exists(split, src, tgt, lang): if src is not None: diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index f05dacb617..ce81da96fd 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -232,7 +232,7 @@ def setup_task(cls, args, **kwargs): return cls(args, src_dict, tgt_dict) - def load_dataset(self, split, epoch=0, combine=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. 
Args: @@ -240,7 +240,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 - data_path = paths[epoch % len(paths)] + data_path = paths[(epoch - 1) % len(paths)] # infer langcode src, tgt = self.args.source_lang, self.args.target_lang diff --git a/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/tasks/translation_from_pretrained_bart.py index 5057d2f446..bf925a8669 100644 --- a/fairseq/tasks/translation_from_pretrained_bart.py +++ b/fairseq/tasks/translation_from_pretrained_bart.py @@ -53,7 +53,7 @@ def __init__(self, args, src_dict, tgt_dict): d.add_symbol('[{}]'.format(l)) d.add_symbol('') - def load_dataset(self, split, epoch=0, combine=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: @@ -61,7 +61,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): """ paths = self.args.data.split(':') assert len(paths) > 0 - data_path = paths[epoch % len(paths)] + data_path = paths[(epoch - 1) % len(paths)] # infer langcode src, tgt = self.args.source_lang, self.args.target_lang diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py index 0fcaa40d9d..093c340cab 100644 --- a/fairseq/tasks/translation_lev.py +++ b/fairseq/tasks/translation_lev.py @@ -29,7 +29,7 @@ def add_args(parser): default='random_delete', choices=['random_delete', 'random_mask', 'no_noise', 'full_mask']) - def load_dataset(self, split, epoch=0, combine=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: @@ -37,7 +37,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 - data_path = paths[epoch % len(paths)] + data_path = paths[(epoch - 1) % len(paths)] # infer langcode src, tgt = self.args.source_lang, self.args.target_lang diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 740169e79e..c6a25680a5 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -477,7 +477,7 @@ def zero_grad(self): self.optimizer.zero_grad() def lr_step(self, epoch, val_loss=None): - """Adjust the learning rate based on the validation loss.""" + """Adjust the learning rate at the end of the epoch.""" self.lr_scheduler.step(epoch, val_loss) # prefer updating the LR based on the number of steps return self.lr_step_update() diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index e8b1e00ef0..21771ff2f0 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -56,7 +56,7 @@ def main(args, init_distributed=False): # Load valid dataset (we load training data below, based on the latest checkpoint) for valid_sub_split in args.valid_subset.split(','): - task.load_dataset(valid_sub_split, combine=False, epoch=0) + task.load_dataset(valid_sub_split, combine=False, epoch=1) # Build model and criterion model = task.build_model(args) @@ -90,7 +90,7 @@ def main(args, init_distributed=False): while ( lr > args.min_lr and ( - epoch_itr.epoch < max_epoch + epoch_itr.epoch <= max_epoch # allow resuming training from the final checkpoint or epoch_itr._next_epoch_itr is not None ) @@ -148,7 +148,7 @@ def train(args, trainer, task, epoch_itr): # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, - shuffle=(epoch_itr.epoch >= args.curriculum), + shuffle=(epoch_itr.epoch > args.curriculum), ) update_freq = ( args.update_freq[epoch_itr.epoch - 1] 
diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index 527e975788..6e7eaf5447 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -61,7 +61,7 @@ def main(args, override_args=None): for subset in args.valid_subset.split(','): try: - task.load_dataset(subset, combine=False, epoch=0) + task.load_dataset(subset, combine=False, epoch=1) dataset = task.dataset(subset) except KeyError: raise Exception('Cannot find dataset: ' + subset) diff --git a/tests/test_train.py b/tests/test_train.py index 7f2eb40204..734d0e8601 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -67,7 +67,6 @@ def setUp(self): [p.start() for p in self.applied_patches] def test_load_partial_checkpoint(self): - with contextlib.redirect_stdout(StringIO()): trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50) trainer.get_train_iterator = MagicMock(return_value=epoch_itr) @@ -110,7 +109,7 @@ def test_load_full_checkpoint(self): def test_load_no_checkpoint(self): with contextlib.redirect_stdout(StringIO()): - trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0) + trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0) trainer.get_train_iterator = MagicMock(return_value=epoch_itr) self.patches['os.path.isfile'].return_value = False
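
Note on the recurring pattern above: every task's load_dataset now treats epoch as 1-based, so the shard lookup shifts from paths[epoch % len(paths)] to paths[(epoch - 1) % len(paths)], and the iterators clamp their starting epoch with max(epoch, 1). A minimal, self-contained sketch of that selection logic follows; it is not part of the patch, and the function name and example paths are purely illustrative:

    def select_data_path(paths, epoch):
        """Pick the data shard for a 1-based epoch, cycling through the
        available paths the way the updated load_dataset methods do."""
        assert len(paths) > 0
        assert epoch >= 1, "epochs are 1-based after this change"
        return paths[(epoch - 1) % len(paths)]

    # With two shards, epoch 1 reads shard0, epoch 2 reads shard1,
    # and epoch 3 wraps back around to shard0.
    paths = ["data-bin/shard0", "data-bin/shard1"]
    assert [select_data_path(paths, e) for e in (1, 2, 3)] == [
        "data-bin/shard0", "data-bin/shard1", "data-bin/shard0",
    ]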