From c066ef78eac7561e6404046472c8c3e3f178f731 Mon Sep 17 00:00:00 2001 From: Wil Kong Date: Fri, 16 Aug 2024 01:52:24 +0800 Subject: [PATCH] Fix The Issue With PL 2.4.0 (#10137) * Drop customized register_key since PL 2.4.0 doesn't invoke additional host & device sync in it. * Remove register_key. --- nemo/utils/callbacks/cuda_graph.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/nemo/utils/callbacks/cuda_graph.py b/nemo/utils/callbacks/cuda_graph.py index ec0650a90e7d..e18a2a3d3b6c 100644 --- a/nemo/utils/callbacks/cuda_graph.py +++ b/nemo/utils/callbacks/cuda_graph.py @@ -139,16 +139,6 @@ def to_tensor(self, value, name): return value -def register_key(self, key, meta, value): - # PyTorch Lightning creates all metrics on GPU, but creating the metric on - # its input device is prefered. - # Refer to: https://github.com/Lightning-AI/pytorch-lightning/blob/2.0.7/src/lightning/pytorch/trainer/connectors/logger_connector/result.py#L409 - metric = _ResultMetric(meta, isinstance(value, torch.Tensor)) - device = value.device if isinstance(value, torch.Tensor) else self.device - metric = metric.to(device) - self[key] = metric - - def update_metrics(self, key, value, batch_size): # PyTorch Lightning always move all metrics to GPU, but moving the metric to # its input device is prefered. @@ -374,8 +364,6 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") # Use smart metrics to avoid syncs LightningModule.__orig_to_tensor__ = LightningModule._LightningModule__to_tensor LightningModule._LightningModule__to_tensor = to_tensor - _ResultCollection.__orig_register_key__ = _ResultCollection.register_key - _ResultCollection.register_key = register_key _ResultCollection.__orig_update_metrics__ = _ResultCollection.update_metrics _ResultCollection.update_metrics = update_metrics @@ -409,8 +397,6 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - LightningModule._LightningModule__to_tensor = LightningModule.__orig_to_tensor__ del LightningModule.__orig_to_tensor__ - _ResultCollection.register_key = _ResultCollection.__orig_register_key__ - del _ResultCollection.__orig_register_key__ _ResultCollection.update_metrics = _ResultCollection.__orig_update_metrics__ del _ResultCollection.__orig_update_metrics__