Skip to content

Commit

Permalink
get_seq_order_for_epoch: cleaned up calculation of random seed
Browse files: browse the repository at this point in the history
  • Loading branch information
patrick-wilken committed Aug 9, 2021
1 parent ab55b41 commit 5ac3d99
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions returnn/datasets/basic.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -377,14 +377,10 @@ def get_seq_order_for_epoch(self, epoch, num_seqs, get_seq_len=None):
"""
partition_epoch = self.partition_epoch or 1
repeat_epoch = self.repeat_epoch or 1
if not epoch:
epoch = 1
full_epoch = epoch
if partition_epoch > 1:
full_epoch = (epoch - 1) // partition_epoch + 1
assert num_seqs > 0
if self._seq_order_seq_lens_file:
get_seq_len = self._get_seq_order_seq_lens_by_idx

if self.seq_ordering == 'default':
seq_index = range(num_seqs)
elif self.seq_ordering.startswith("default_every_n:"):
Expand All @@ -406,7 +402,7 @@ def get_seq_order_for_epoch(self, epoch, num_seqs, get_seq_len=None):
tmp = self.seq_ordering.split(':')
nth = int(tmp[1]) if len(tmp) > 1 else 1
# Keep this deterministic! Use fixed seed.
rnd_seed = (full_epoch - 1) // nth + 1
rnd_seed = self._get_random_seed_for_epoch(epoch=epoch, num_epochs_fixed=nth)
numpy.random.seed(rnd_seed)
seq_index = numpy.random.permutation(num_seqs)
elif self.seq_ordering.startswith('sort_bin_shuffle'):
Expand All @@ -418,7 +414,7 @@ def get_seq_order_for_epoch(self, epoch, num_seqs, get_seq_len=None):
nth = 1
else:
nth = int(tmp[1])
rnd_seed = ((full_epoch - 1) // nth + 1) if full_epoch else 1
rnd_seed = self._get_random_seed_for_epoch(epoch=epoch, num_epochs_fixed=nth)
numpy.random.seed(rnd_seed)
seq_index = numpy.random.permutation(num_seqs).tolist()
seq_index.sort(key=get_seq_len) # Sort by length, starting with shortest.
Expand Down Expand Up @@ -454,7 +450,7 @@ def get_seq_order_for_epoch(self, epoch, num_seqs, get_seq_len=None):
nth = 1
else:
nth = int(tmp[1])
rnd_seed = ((full_epoch - 1) // nth + 1) if full_epoch else 1
rnd_seed = self._get_random_seed_for_epoch(epoch=epoch, num_epochs_fixed=nth)
numpy.random.seed(rnd_seed)
seq_index = numpy.random.permutation(num_seqs)
out_index = []
Expand All @@ -468,6 +464,7 @@ def get_seq_order_for_epoch(self, epoch, num_seqs, get_seq_len=None):
seq_index = out_index
else:
assert False, "invalid batching specified: " + self.seq_ordering

if self.unique_seq_tags:
# Note: This is as generic as possible, but requires that get_all_tags is implemented.
all_seq_tags = self.get_all_tags()
Expand Down Expand Up @@ -514,13 +511,16 @@ def _apply_partition_epoch(cls, seq_index, partition_epoch, epoch):

return seq_index

def _get_random_seed_for_epoch(self, epoch):
def _get_random_seed_for_epoch(self, epoch, num_epochs_fixed=1):
"""
Compute the random seed to use for the given epoch.

The epoch number is first reduced by ``num_epochs_fixed`` and then by ``self.partition_epoch``
(each step via ``(x - 1) // n + 1``, applied only if ``n > 1``),
and ``self.random_seed_offset`` is added to the result.

:param int|None epoch: epoch number (presumably 1-based); a falsy value (None or 0) is treated as 1
:param int num_epochs_fixed: keep random seed fixed for n subsequent full epochs
:return: the reduced epoch index plus ``self.random_seed_offset``
:rtype: int
"""
# A falsy partition_epoch (None or 0) is treated as 1, i.e. the partition step below is skipped.
partition_epoch = self.partition_epoch or 1
full_epoch = epoch or 1
if num_epochs_fixed > 1:
# Map every run of num_epochs_fixed consecutive epoch numbers to the same index.
full_epoch = (full_epoch - 1) // num_epochs_fixed + 1
if partition_epoch > 1:
# Map every run of partition_epoch consecutive indices to the same full-epoch index.
# Nested floor divisions commute, so the order of the two reductions does not change the result.
full_epoch = (full_epoch - 1) // partition_epoch + 1
return full_epoch + self.random_seed_offset
Expand Down

0 comments on commit 5ac3d99

Please sign in to comment.