diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 25daf58d22..5db2a94a30 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -138,9 +138,17 @@ def get_batch_iterator(
         )
 
         # create mini-batches with given size constraints
-        batch_sampler = data_utils.batch_by_size_tpu(
-            indices, dataset.num_tokens, self.args.input_shapes
-        )
+        if getattr(self.args, 'use_gpu', True):
+            batch_sampler = data_utils.batch_by_size(
+                indices, dataset.num_tokens, max_tokens=max_tokens,
+                max_sentences=max_sentences,
+                required_batch_size_multiple=required_batch_size_multiple,
+            )
+        else:
+            batch_sampler = data_utils.batch_by_size_tpu(
+                indices, dataset.num_tokens,
+                getattr(self.args, 'input_shapes', None)
+            )
 
         # return a reusable, sharded iterator
         return iterators.EpochBatchIterator(
diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py
index 70b9a8e972..d76c574272 100644
--- a/fairseq/tasks/translation.py
+++ b/fairseq/tasks/translation.py
@@ -188,7 +188,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs):
             left_pad_target=self.args.left_pad_target,
             max_source_positions=self.args.max_source_positions,
             max_target_positions=self.args.max_target_positions,
-            input_shapes=self.args.input_shapes,
+            input_shapes=getattr(self.args, 'input_shapes', None),
         )
 
     def build_dataset_for_inference(self, src_tokens, src_lengths):