Skip to content

Commit

Permalink
Optionally disable logging in the data sampler to support predict_ste…
Browse files Browse the repository at this point in the history
…p which does not support logging

Signed-off-by: John St John <jstjohn@nvidia.com>
  • Loading branch information
jstjohn committed Aug 13, 2024
1 parent 7bb4271 commit 49549af
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions nemo/lightning/pytorch/plugins/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ def __init__(
rampup_batch_size: Optional[List[int]] = None,
dataloader_type: Literal["single", "cyclic"] = "single",
init_consumed_samples: int = 0,
output_log: bool = True,
):
self.seq_len = seq_len
self.output_log = output_log
self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
self.rampup_batch_size = rampup_batch_size
Expand Down Expand Up @@ -94,28 +96,31 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModul

# TODO: Add consumed samples
consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_consumed_samples)

pl_module.log(
'consumed_samples',
consumed_samples,
prog_bar=True,
rank_zero_only=True,
batch_size=1,
)
if self.output_log:
# You may need to turn off logging, for example when doing trainer.predict(model, data)
pl_module.log(
'consumed_samples',
consumed_samples,
prog_bar=True,
rank_zero_only=True,
batch_size=1,
)

self.prev_consumed_samples = consumed_samples

update_num_microbatches(
consumed_samples=consumed_samples,
consistency_check=False,
)
pl_module.log(
"global_batch_size",
self.current_global_batch_size,
prog_bar=True,
rank_zero_only=True,
batch_size=1,
)
if self.output_log:
# You may need to turn off logging, for example when doing trainer.predict(model, data)
pl_module.log(
"global_batch_size",
self.current_global_batch_size,
prog_bar=True,
rank_zero_only=True,
batch_size=1,
)
self.if_first_step = 1

@property
Expand Down

0 comments on commit 49549af

Please sign in to comment.