Use 1-based indexing for epochs everywhere (#1053)
Summary:
We are somewhat inconsistent in whether we're using 0-based or 1-based indexing for epochs. This fixes things to be 1-based internally, matching the 1-based indexing already used for logging and checkpoint naming.
Pull Request resolved: fairinternal/fairseq-py#1053

Reviewed By: spencerp

Differential Revision: D20160715

Pulled By: myleott

fbshipit-source-id: 4ed94f9c371e1bfe29bcfa087fa6756507d6e627
myleott authored and facebook-github-bot committed Mar 5, 2020
1 parent 4171b83 commit aa79bb9
Showing 26 changed files with 63 additions and 116 deletions.
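
Most of the hunks below follow the same pattern: the default `epoch` argument changes from 0 to 1, and per-epoch data-shard selection subtracts 1 before taking the modulo. As a quick illustration of that convention, here is a small standalone sketch (the helper name is hypothetical and not part of the commit):

def select_data_path(paths, epoch):
    # epochs are 1-based, so the first epoch maps to the first data shard
    assert len(paths) > 0 and epoch >= 1
    return paths[(epoch - 1) % len(paths)]

paths = ["data-bin/shard0", "data-bin/shard1"]
assert select_data_path(paths, epoch=1) == "data-bin/shard0"
assert select_data_path(paths, epoch=2) == "data-bin/shard1"
assert select_data_path(paths, epoch=3) == "data-bin/shard0"  # wraps around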
2 changes: 1 addition & 1 deletion examples/roberta/commonsense_qa/commonsense_qa_task.py
@@ -66,7 +66,7 @@ def setup_task(cls, args, **kwargs):

        return cls(args, vocab)

-    def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
        """Load a given dataset split.
        Args:
4 changes: 2 additions & 2 deletions examples/roberta/wsc/wsc_task.py
@@ -101,7 +101,7 @@ def binarize_with_mask(self, txt, prefix, suffix, leading_space, trailing_space)
        mask[mask_start:mask_start + mask_size] = 1
        return toks, mask

-    def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
        """Load a given dataset split.
        Args:
@@ -281,7 +281,7 @@ def setup_task(cls, args, **kwargs):

        return cls(args, vocab)

-    def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
        """Load a given dataset split.
        Args:
2 changes: 1 addition & 1 deletion fairseq/benchmark/dummy_lm.py
@@ -42,7 +42,7 @@ def setup_task(cls, args, **kwargs):

        return cls(args, dictionary)

-    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
2 changes: 1 addition & 1 deletion fairseq/benchmark/dummy_masked_lm.py
@@ -53,7 +53,7 @@ def setup_task(cls, args, **kwargs):

        return cls(args, dictionary)

-    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
8 changes: 7 additions & 1 deletion fairseq/checkpoint_utils.py
@@ -109,6 +109,7 @@ def is_better(a, b):
            if os.path.lexists(old_chk):
                os.remove(old_chk)

+
def load_checkpoint(args, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.
@@ -150,7 +151,7 @@ def load_checkpoint(args, trainer, **passthrough_args):
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
-            epoch=0, load_dataset=True, **passthrough_args
+            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)
@@ -349,6 +350,11 @@ def _upgrade_state_dict(state):
        state["args"].dataset_impl = "raw"
    elif getattr(state["args"], "lazy_load", False):
        state["args"].dataset_impl = "lazy"
+    # epochs start at 1
+    state["extra_state"]["train_iterator"]["epoch"] = max(
+        state["extra_state"]["train_iterator"].get("epoch", 1),
+        1,
+    )

    # set any missing default values in the task, model or other registries
    registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task])
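
The `_upgrade_state_dict` hunk above clamps the stored epoch so that checkpoints written before this change (which may have saved `epoch=0`) resume as epoch 1. A minimal standalone sketch of that behaviour, using a plain dict as a stand-in for a loaded checkpoint:

def upgrade_epoch(state):
    # epochs start at 1, so an old 0-based checkpoint is bumped up on load
    it = state["extra_state"]["train_iterator"]
    it["epoch"] = max(it.get("epoch", 1), 1)
    return state

old_ckpt = {"extra_state": {"train_iterator": {"epoch": 0, "iterations_in_epoch": 0}}}
assert upgrade_epoch(old_ckpt)["extra_state"]["train_iterator"]["epoch"] == 1

new_ckpt = {"extra_state": {"train_iterator": {"epoch": 7}}}
assert upgrade_epoch(new_ckpt)["extra_state"]["train_iterator"]["epoch"] == 7  # unchanged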
2 changes: 0 additions & 2 deletions fairseq/data/__init__.py
@@ -38,7 +38,6 @@
from .resampling_dataset import ResamplingDataset
from .roll_dataset import RollDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
-from .sharded_dataset import ShardedDataset
from .sort_dataset import SortDataset
from .strip_token_dataset import StripTokenDataset
from .subsample_dataset import SubsampleDataset
@@ -96,7 +95,6 @@
    'ResamplingDataset',
    'RightPadDataset',
    'RoundRobinZipDatasets',
-    'ShardedDataset',
    'ShardedIterator',
    'SortDataset',
    'StripTokenDataset',
19 changes: 12 additions & 7 deletions fairseq/data/iterators.py
@@ -100,17 +100,18 @@ def load_state_dict(self, state_dict):

class StreamingEpochBatchIterator(EpochBatchIterating):
    def __init__(
-        self, dataset, epoch=0, num_shards=1, shard_id=0,
+        self, dataset, epoch=1, num_shards=1, shard_id=0,
    ):
        assert isinstance(dataset, torch.utils.data.IterableDataset)
        self.dataset = dataset
-        self.epoch = epoch
+        self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
        self._current_epoch_iterator = None
        self.num_shards = num_shards
        self.shard_id = shard_id

    def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
-        self.epoch += 1
+        if self._current_epoch_iterator is not None and self.end_of_epoch():
+            self.epoch += 1
        self.dataset.set_epoch(self.epoch)
        self._current_epoch_iterator = CountingIterator(
            iterable=ShardedIterator(
@@ -165,12 +166,12 @@ class EpochBatchIterator(EpochBatchIterating):
            loading. 0 means the data will be loaded in the main process
            (default: 0).
        epoch (int, optional): the epoch to start the iterator from
-            (default: 0).
+            (default: 1).
    """

    def __init__(
        self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
-        num_workers=0, epoch=0,
+        num_workers=0, epoch=1,
    ):
        assert isinstance(dataset, torch.utils.data.Dataset)
        self.dataset = dataset
@@ -181,7 +182,7 @@ def __init__(
        self.shard_id = shard_id
        self.num_workers = num_workers

-        self.epoch = epoch
+        self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
        self.shuffle = True
        self._cur_epoch_itr = None
        self._next_epoch_itr = None
@@ -204,7 +205,8 @@ def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
            self._cur_epoch_itr = self._next_epoch_itr
            self._next_epoch_itr = None
        else:
-            self.epoch += 1
+            if self._cur_epoch_itr is not None and self.end_of_epoch():
+                self.epoch += 1
            self._cur_epoch_itr = self._get_iterator_for_epoch(
                self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
            )
@@ -244,6 +246,9 @@ def load_state_dict(self, state_dict):
                shuffle=state_dict.get('shuffle', True),
                offset=itr_pos,
            )
+            if self._next_epoch_itr is None:
+                # we finished the epoch, increment epoch counter
+                self.epoch += 1

    def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False, offset=0):
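
With these changes the epoch counter no longer jumps on the very first call to `next_epoch_itr()`: it starts at 1 and is only incremented once the previous epoch's iterator has been exhausted. A toy model of that rule (a sketch only; the real `EpochBatchIterator` also handles sharding, shuffling, workers, and state saving):

class ToyEpochIterator:
    """Mimics how the iterators above advance their 1-based epoch counter."""

    def __init__(self, dataset, epoch=1):
        self.dataset = list(dataset)
        self.epoch = max(epoch, 1)          # we use 1-based indexing for epochs
        self._consumed = None               # items consumed in the current epoch

    def end_of_epoch(self):
        return self._consumed is not None and self._consumed >= len(self.dataset)

    def next_epoch_itr(self):
        # only advance once the previous epoch was actually finished
        if self._consumed is not None and self.end_of_epoch():
            self.epoch += 1
        self._consumed = len(self.dataset)  # pretend we iterate the whole epoch
        return iter(self.dataset)

itr = ToyEpochIterator(range(3))
assert itr.epoch == 1        # the first epoch is 1, and the first call below keeps it at 1
list(itr.next_epoch_itr())   # epoch 1
list(itr.next_epoch_itr())   # epoch 2
assert itr.epoch == 2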
4 changes: 2 additions & 2 deletions fairseq/data/resampling_dataset.py
@@ -31,7 +31,7 @@ class ResamplingDataset(BaseWrapperDataset):
        batch_by_size (bool): whether or not to batch by sequence length
            (default: True).
        seed (int): RNG seed to use (default: 0).
-        epoch (int): starting epoch number (default: 0).
+        epoch (int): starting epoch number (default: 1).
    """

    def __init__(
@@ -42,7 +42,7 @@ def __init__(
        size_ratio=1.0,
        batch_by_size=True,
        seed=0,
-        epoch=0,
+        epoch=1,
    ):
        super().__init__(dataset)
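
`ResamplingDataset` draws a new sample of the underlying dataset each epoch, so its starting epoch also defaults to 1 now. A rough sketch of the per-epoch reseeding idea under that assumption (not the real class, which additionally supports weights, `size_ratio`, and `batch_by_size`):

import numpy as np

class ToyResamplingDataset:
    def __init__(self, dataset, seed=0, epoch=1):
        self.dataset = list(dataset)
        self.seed = seed
        self.set_epoch(epoch)

    def set_epoch(self, epoch):
        # combine the base seed with the 1-based epoch so every epoch gets a
        # different but reproducible sample of indices
        self.epoch = max(epoch, 1)
        rng = np.random.RandomState([self.seed, self.epoch])
        self._indices = rng.choice(len(self.dataset), len(self.dataset), replace=True)

    def __getitem__(self, i):
        return self.dataset[self._indices[i]]

    def __len__(self):
        return len(self._indices)

ds = ToyResamplingDataset(["a", "b", "c", "d"], seed=0, epoch=1)
first = [ds[i] for i in range(len(ds))]
ds.set_epoch(2)
second = [ds[i] for i in range(len(ds))]
# the two epochs usually differ, but each is reproducible given (seed, epoch)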
60 changes: 0 additions & 60 deletions fairseq/data/sharded_dataset.py

This file was deleted.

7 changes: 4 additions & 3 deletions fairseq/tasks/cross_lingual_lm.py
@@ -102,7 +102,7 @@ def _load_single_lang_dataset(self, split, epoch):

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
-        data_path = paths[epoch % len(paths)]
+        data_path = paths[(epoch - 1) % len(paths)]

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
@@ -136,8 +136,9 @@ def _load_single_lang_dataset(self, split, epoch):

        return dataset, sizes

-    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
@@ -165,5 +166,5 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs):

        self.datasets[split] = MultiCorpusSampledDataset(dataset_map)
        logger.info('{} {} {} examples'.format(
-            utils.split_paths(self.args.data)[epoch], split, len(self.datasets[split]))
+            utils.split_paths(self.args.data)[epoch - 1], split, len(self.datasets[split]))
        )
5 changes: 2 additions & 3 deletions fairseq/tasks/denoising.py
@@ -104,16 +104,15 @@ def setup_task(cls, args, **kwargs):
            args.shuffle_instance = False
        return cls(args, dictionary)

-    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
-        data_path = paths[epoch % len(paths)]
+        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
4 changes: 2 additions & 2 deletions fairseq/tasks/fairseq_task.py
@@ -116,7 +116,7 @@ def get_batch_iterator(
        num_shards=1,
        shard_id=0,
        num_workers=0,
-        epoch=0,
+        epoch=1,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.
@@ -143,7 +143,7 @@ def get_batch_iterator(
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
-                (default: 0).
+                (default: 1).

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
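
Since `get_batch_iterator()` now defaults to `epoch=1`, callers that do not pass an epoch start from epoch 1, and checkpoint names stay 1-based (the first epoch saves `checkpoint1.pt`). A hedged sketch of an outer loop under that convention; `make_epoch_batches` is a hypothetical stand-in, not the fairseq API:

def make_epoch_batches(data, batch_size, epoch):
    # stand-in for a per-epoch batch iterator; epochs are 1-based
    assert epoch >= 1
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

data = list(range(10))
for epoch in range(1, 4):  # epochs 1, 2, 3
    for batch in make_epoch_batches(data, batch_size=4, epoch=epoch):
        pass  # forward/backward/optimizer step would go here
    print("epoch {} done -> checkpoint{}.pt".format(epoch, epoch))  # 1-based names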
4 changes: 2 additions & 2 deletions fairseq/tasks/language_modeling.py
@@ -148,7 +148,7 @@ def build_model(self, args):

        return model

-    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

@@ -157,7 +157,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0

-        data_path = paths[epoch % len(paths)]
+        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
5 changes: 3 additions & 2 deletions fairseq/tasks/legacy_masked_lm.py
@@ -78,16 +78,17 @@ def setup_task(cls, args, **kwargs):

        return cls(args, dictionary)

-    def load_dataset(self, split, epoch=0, combine=False):
+    def load_dataset(self, split, epoch=1, combine=False):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        loaded_datasets = []

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
-        data_path = paths[epoch % len(paths)]
+        data_path = paths[(epoch - 1) % len(paths)]
        logger.info("data_path", data_path)

        for k in itertools.count():
4 changes: 2 additions & 2 deletions fairseq/tasks/masked_lm.py
@@ -75,15 +75,15 @@ def setup_task(cls, args, **kwargs):
        logger.info('dictionary: {} types'.format(len(dictionary)))
        return cls(args, dictionary)

-    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
-        data_path = paths[epoch % len(paths)]
+        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
6 changes: 3 additions & 3 deletions fairseq/tasks/multilingual_denoising.py
@@ -87,15 +87,15 @@ def _get_sample_prob(self, dataset_lens):
        smoothed_prob = smoothed_prob / smoothed_prob.sum()
        return smoothed_prob

-    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        paths = self.args.data.split(':')
        assert len(paths) > 0
-        data_path = paths[epoch % len(paths)]
+        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        if self.langs is None:
(diffs for the remaining changed files not shown)
