diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py
index d7fd9a79372f2..248f8035b1a88 100644
--- a/nemo/lightning/pytorch/plugins/data_sampler.py
+++ b/nemo/lightning/pytorch/plugins/data_sampler.py
@@ -96,8 +96,7 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModul
         self.prev_global_batch_size = self.current_global_batch_size

-        # TODO: Add consumed samples
-        consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_consumed_samples)
+        consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_global_step)
         if self.output_log:
             # You may need to turn off logging, for example when doing trainer.predict(model, data)
             pl_module.log(
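
The fix above corrects a unit mismatch: `trainer.global_step` counts optimizer steps, so the offset subtracted from it must also be a step count (`init_global_step`), not a sample count (`init_consumed_samples`). Below is a minimal sketch of why this matters, assuming `compute_consumed_samples` takes the number of steps since resume and scales it by the global batch size; the standalone function and the concrete numbers are illustrative, not the actual NeMo implementation.

```python
def compute_consumed_samples(init_consumed_samples: int,
                             steps_since_resume: int,
                             global_batch_size: int) -> int:
    # Samples seen before the resume point, plus samples seen since:
    # each step consumes one global batch.
    return init_consumed_samples + steps_since_resume * global_batch_size


# Hypothetical resume scenario: training resumed at step 100 after
# consuming 25,600 samples, now at global_step 149 with batch size 256.
init_global_step = 100
init_consumed_samples = 25_600
global_step = 149
global_batch_size = 256

# Fixed code: subtract a step count from a step count.
steps_since_resume = global_step + 1 - init_global_step  # 50
print(compute_consumed_samples(init_consumed_samples,
                               steps_since_resume,
                               global_batch_size))  # 25_600 + 50 * 256

# The old code instead computed global_step + 1 - init_consumed_samples,
# i.e. 149 + 1 - 25_600 = -25_450 "steps", yielding a nonsense
# (here negative) consumed-sample count after resuming from a checkpoint.
```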