
Fix scripting causing false positive deprecation warnings (#10555)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
2 people authored and lexierule committed Nov 16, 2021
1 parent 122e503 commit 9e45024
Showing 3 changed files with 15 additions and 9 deletions.
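Background: this commit extends the fix from #10470, where a `_running_torchscript` flag on the `LightningModule` suppresses deprecation warnings while the model is being scripted or traced, so that tracing machinery touching deprecated attributes no longer fires false positives. A minimal sketch of the guard pattern, assuming a hypothetical deprecated property (illustrative only, not Lightning's actual source):

import warnings

class FlagGuardSketch:
    """Illustrative sketch of the flag-guard pattern; not Lightning's source."""

    def __init__(self):
        # Toggled to True around torch.jit scripting/tracing calls.
        self._running_torchscript = False

    @property
    def some_deprecated_property(self):  # hypothetical name, for illustration
        # Scripting and graph tracing inspect module attributes, which would
        # otherwise emit this warning with no user code involved.
        if not self._running_torchscript:
            warnings.warn("`some_deprecated_property` is deprecated.", DeprecationWarning)
        return None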
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -27,7 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed sampler replacement logic with `overfit_batches` to only replace the sampler when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))


-- Fixed `to_torchscript()` causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/issues/10470))
+- Fixed scripting causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/pull/10470), [#10555](https://github.com/PyTorchLightning/pytorch-lightning/pull/10555))


 - Fixed `isinstance` not working with `init_meta_context`, materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493))
2 changes: 2 additions & 0 deletions pytorch_lightning/loggers/tensorboard.py
@@ -240,7 +240,9 @@ def log_graph(self, model: "pl.LightningModule", input_array=None):

             if input_array is not None:
                 input_array = model._apply_batch_transfer_handler(input_array)
+                model._running_torchscript = True
                 self.experiment.add_graph(model, input_array)
+                model._running_torchscript = False
             else:
                 rank_zero_warn(
                     "Could not log computational graph since the"
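The two added lines mark the module as being scripted while `add_graph` traces it, then clear the flag so deprecation warnings behave normally afterwards. A minimal usage sketch of how this code path is reached (the model is invented for illustration; `log_graph=True` and `example_input_array` are real Lightning APIs):

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger

class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)
        # Gives log_graph() an input to trace with when none is passed in.
        self.example_input_array = torch.randn(1, 4)

    def forward(self, x):
        return self.layer(x)

logger = TensorBoardLogger("tb_logs", log_graph=True)
logger.log_graph(TinyModel())  # calls self.experiment.add_graph(...) internally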
20 changes: 12 additions & 8 deletions pytorch_lightning/plugins/training_type/ipu.py
@@ -237,21 +237,25 @@ def to_tensor(x):
         args = apply_to_collection(args, dtype=(int, float), function=to_tensor)
         return args

-    def training_step(self, *args, **kwargs):
+    def _step(self, stage: RunningStage, *args: Any, **kwargs: Any):
         args = self._prepare_input(args)
-        return self.poptorch_models[RunningStage.TRAINING](*args, **kwargs)
+        poptorch_model = self.poptorch_models[stage]
+        self.lightning_module._running_torchscript = True
+        out = poptorch_model(*args, **kwargs)
+        self.lightning_module._running_torchscript = False
+        return out
+
+    def training_step(self, *args, **kwargs):
+        return self._step(RunningStage.TRAINING, *args, **kwargs)

     def validation_step(self, *args, **kwargs):
-        args = self._prepare_input(args)
-        return self.poptorch_models[RunningStage.VALIDATING](*args, **kwargs)
+        return self._step(RunningStage.VALIDATING, *args, **kwargs)

     def test_step(self, *args, **kwargs):
-        args = self._prepare_input(args)
-        return self.poptorch_models[RunningStage.TESTING](*args, **kwargs)
+        return self._step(RunningStage.TESTING, *args, **kwargs)

     def predict_step(self, *args, **kwargs):
-        args = self._prepare_input(args)
-        return self.poptorch_models[RunningStage.PREDICTING](*args, **kwargs)
+        return self._step(RunningStage.PREDICTING, *args, **kwargs)

     def teardown(self) -> None:
         # undo dataloader patching
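For reference, `_prepare_input` (the fragment at the top of this hunk) relies on `apply_to_collection`, which maps a function over every leaf of a possibly nested collection that matches the given dtypes. A small illustration, using `torch.tensor` as a stand-in since the body of `to_tensor` is elided above:

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

batch = {"step": 3, "lr": [0.1, 0.01], "tag": "ipu"}
# Every int/float leaf is converted; non-matching leaves pass through unchanged.
converted = apply_to_collection(batch, dtype=(int, float), function=torch.tensor)
# -> {'step': tensor(3), 'lr': [tensor(0.1000), tensor(0.0100)], 'tag': 'ipu'}

Consolidating the four step methods into `_step` also means the `_running_torchscript` flag is toggled in exactly one place, mirroring the `log_graph` change above.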
