Enable usage of batch_by_size for other code paths than just train (#17)
taylanbil authored Nov 7, 2019
1 parent 92f19a2 · commit aa2c3b3
Showing 2 changed files with 12 additions and 4 deletions.
fairseq/tasks/fairseq_task.py (11 additions & 3 deletions)

@@ -138,9 +138,17 @@ def get_batch_iterator(
         )
 
         # create mini-batches with given size constraints
-        batch_sampler = data_utils.batch_by_size_tpu(
-            indices, dataset.num_tokens, self.args.input_shapes
-        )
+        if getattr(self.args, 'use_gpu', True):
+            batch_sampler = data_utils.batch_by_size(
+                indices, dataset.num_tokens, max_tokens=max_tokens,
+                max_sentences=max_sentences,
+                required_batch_size_multiple=required_batch_size_multiple,
+            )
+        else:
+            batch_sampler = data_utils.batch_by_size_tpu(
+                indices, dataset.num_tokens,
+                getattr(self.args, 'input_shapes', None)
+            )
 
         # return a reusable, sharded iterator
         return iterators.EpochBatchIterator(
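
On GPU (use_gpu defaults to True) the standard dynamic batching in data_utils.batch_by_size is used; on TPU, batches must come from a fixed set of shapes so XLA does not recompile the graph for every new batch geometry. As a rough sketch of the idea only (the real batch_by_size_tpu lives in this repository's data_utils and may differ), each sample can be bucketed into the smallest predefined (batch_size, max_seq_len) pair that fits it; the input_shapes format below is an assumption:

    # Sketch only: illustrates fixed-shape bucketing for TPU, not the
    # actual data_utils.batch_by_size_tpu implementation.
    def batch_by_size_tpu_sketch(indices, num_tokens_fn, input_shapes):
        # input_shapes: hypothetical list of (batch_size, max_seq_len) pairs,
        # sorted by max_seq_len ascending, e.g. [(256, 16), (128, 32), (64, 64)]
        buckets = [[] for _ in input_shapes]
        for idx in indices:
            sample_len = num_tokens_fn(idx)
            for bucket, (batch_size, max_seq_len) in zip(buckets, input_shapes):
                if sample_len > max_seq_len:
                    continue  # too long for this bucket; try the next larger one
                bucket.append(idx)
                if len(bucket) == batch_size:
                    yield list(bucket)  # a full, fixed-size batch
                    bucket.clear()
                break  # placed in the smallest bucket that fits
        # partially filled buckets (and samples longer than every
        # max_seq_len) are silently dropped in this sketch
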
fairseq/tasks/translation.py (1 addition & 1 deletion)

@@ -188,7 +188,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs):
             left_pad_target=self.args.left_pad_target,
             max_source_positions=self.args.max_source_positions,
             max_target_positions=self.args.max_target_positions,
-            input_shapes=self.args.input_shapes,
+            input_shapes=getattr(self.args, 'input_shapes', None),
         )
 
     def build_dataset_for_inference(self, src_tokens, src_lengths):
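
In both files, getattr with a default replaces direct attribute access, so stock configurations that were created without the TPU-only flags keep working. A minimal illustration, using argparse.Namespace as a stand-in for fairseq's parsed args:

    from argparse import Namespace

    gpu_args = Namespace()  # a stock config: no TPU-only flags defined

    # gpu_args.input_shapes would raise AttributeError;
    # getattr with a default degrades gracefully instead:
    print(getattr(gpu_args, 'input_shapes', None))  # None -> non-TPU behavior
    print(getattr(gpu_args, 'use_gpu', True))       # True -> batch_by_size path
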
