Skip to content

Commit

Permalink
Fix The Issue With PL 2.4.0 (#10137)
Browse files Browse the repository at this point in the history
* Drop customized register_key since PL 2.4.0 doesn't invoke additional host & device sync in it.

* Remove register_key.
  • Loading branch information
alpha0422 committed Aug 15, 2024
1 parent 05ced1e commit c066ef7
Showing 1 changed file with 0 additions and 14 deletions.
14 changes: 0 additions & 14 deletions nemo/utils/callbacks/cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,6 @@ def to_tensor(self, value, name):
return value


def register_key(self, key, meta, value):
# PyTorch Lightning creates all metrics on GPU, but creating the metric on
# its input device is prefered.
# Refer to: https://github.com/Lightning-AI/pytorch-lightning/blob/2.0.7/src/lightning/pytorch/trainer/connectors/logger_connector/result.py#L409
metric = _ResultMetric(meta, isinstance(value, torch.Tensor))
device = value.device if isinstance(value, torch.Tensor) else self.device
metric = metric.to(device)
self[key] = metric


def update_metrics(self, key, value, batch_size):
# PyTorch Lightning always move all metrics to GPU, but moving the metric to
# its input device is prefered.
Expand Down Expand Up @@ -374,8 +364,6 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
# Use smart metrics to avoid syncs
LightningModule.__orig_to_tensor__ = LightningModule._LightningModule__to_tensor
LightningModule._LightningModule__to_tensor = to_tensor
_ResultCollection.__orig_register_key__ = _ResultCollection.register_key
_ResultCollection.register_key = register_key
_ResultCollection.__orig_update_metrics__ = _ResultCollection.update_metrics
_ResultCollection.update_metrics = update_metrics

Expand Down Expand Up @@ -409,8 +397,6 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -

LightningModule._LightningModule__to_tensor = LightningModule.__orig_to_tensor__
del LightningModule.__orig_to_tensor__
_ResultCollection.register_key = _ResultCollection.__orig_register_key__
del _ResultCollection.__orig_register_key__
_ResultCollection.update_metrics = _ResultCollection.__orig_update_metrics__
del _ResultCollection.__orig_update_metrics__

Expand Down

0 comments on commit c066ef7

Please sign in to comment.