From aa2c3b39a6f437b3a96b1795d42b0fcdb3fa54a5 Mon Sep 17 00:00:00 2001
From: Taylan Bilal
Date: Thu, 7 Nov 2019 11:45:18 -0800
Subject: [PATCH] Enable usage of batch_by_size for other code paths than just train (#17)

---
 fairseq/tasks/fairseq_task.py | 14 +++++++++++---
 fairseq/tasks/translation.py  |  2 +-
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 25daf58d22..5db2a94a30 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -138,9 +138,17 @@ def get_batch_iterator(
             )
 
         # create mini-batches with given size constraints
-        batch_sampler = data_utils.batch_by_size_tpu(
-            indices, dataset.num_tokens, self.args.input_shapes
-        )
+        if getattr(self.args, 'use_gpu', True):
+            batch_sampler = data_utils.batch_by_size(
+                indices, dataset.num_tokens, max_tokens=max_tokens,
+                max_sentences=max_sentences,
+                required_batch_size_multiple=required_batch_size_multiple,
+            )
+        else:
+            batch_sampler = data_utils.batch_by_size_tpu(
+                indices, dataset.num_tokens,
+                getattr(self.args, 'input_shapes', None)
+            )
 
         # return a reusable, sharded iterator
         return iterators.EpochBatchIterator(
diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py
index 70b9a8e972..d76c574272 100644
--- a/fairseq/tasks/translation.py
+++ b/fairseq/tasks/translation.py
@@ -188,7 +188,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs):
             left_pad_target=self.args.left_pad_target,
             max_source_positions=self.args.max_source_positions,
             max_target_positions=self.args.max_target_positions,
-            input_shapes=self.args.input_shapes,
+            input_shapes=getattr(self.args, 'input_shapes', None),
         )
 
     def build_dataset_for_inference(self, src_tokens, src_lengths):