From 49efb2b6060d1c5bda30da82c1a7579a52a54a3e Mon Sep 17 00:00:00 2001
From: Fei <33940270+YangFei1990@users.noreply.github.com>
Date: Wed, 20 Mar 2024 14:45:11 -0700
Subject: [PATCH] Fix preserve_rng_state for activation checkpointing (#4690)

---
 test/test_operations.py       | 36 +++++++++++++++++++++++++++++++++++
 torch_xla/core/xla_model.py   | 23 ++++++++++++++++++++++
 torch_xla/utils/checkpoint.py | 23 +++++++++++++---------
 3 files changed, 73 insertions(+), 9 deletions(-)

diff --git a/test/test_operations.py b/test/test_operations.py
index e284cac5620..6b28f1fbe89 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -32,6 +32,7 @@
 import torch_xla.core.xla_builder as xb
 import torch_xla.core.xla_op_registry as xor
 import torch_xla.distributed.data_parallel as dp
+from torch_xla.distributed.fsdp import checkpoint_module
 from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
 import torch_xla.debug.metrics as met
 import torch_xla.debug.model_comparator as mc
@@ -2334,6 +2335,41 @@ def test_aten_move_scalar_cuda_to_xla(self):
     self._test_move_tensor_cuda_to_xla(torch.tensor(42))
 
 
+class SimpleModelWithDropout(torch.nn.Module):
+
+  def __init__(self):
+    super().__init__()
+    self.x = torch.nn.Linear(128, 128)
+    self.dropout = torch.nn.Dropout(p=0.1)
+    self.to_save = []
+
+  def save_output(self, output):
+    self.to_save.append(output.detach().cpu())
+
+  def forward(self, inp):
+    x = self.x(inp)
+    output = self.dropout(x)
+    xm.add_step_closure(self.save_output, args=(output,), run_async=False)
+    return output
+
+
+class TestActivationCheckpoint(test_utils.XlaTestCase):
+
+  def test_dropout(self):
+    device = xm.xla_device()
+    model = SimpleModelWithDropout().to(device)
+    model = checkpoint_module(model)
+    _input = torch.randn(128, 128, requires_grad=True)
+    _input = _input.to(device)
+    output = model(_input)
+    output = torch.sum(output)
+    output.backward()
+    xm.mark_step()
+    same_output = torch.allclose(model.to_save[0], model.to_save[1])
+    self.assertTrue(same_output,
+                    f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")
+
+
 if __name__ == '__main__':
   torch.set_default_dtype(torch.float32)
   torch.manual_seed(42)
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 28622fdafc2..9b2cd139e88 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -1,3 +1,4 @@
+import contextlib
 import io
 import itertools
 import logging
@@ -1263,6 +1264,28 @@ def get_rng_state(device=None):
   return torch_xla._XLAC._xla_get_rng_seed(str(device) if device else '')
 
 
+@contextlib.contextmanager
+def fork_rng(device=None, enabled=True):
+  """
+  Forks the RNG, so that when you return, the RNG is reset to the state that it was previously in.
+  Args:
+    device (string, optional): The device where the RNG state needs to be set. If missing the default device seed will be set.
+    enabled (bool): if ``False``, the RNG is not forked. This is a convenience argument for easily disabling the context manager without having to delete it and unindent your Python code under it.
+  """
+  if not enabled:
+    yield
+    return
+
+  if device is None:
+    device = torch_xla._XLAC._xla_get_default_device()
+  xla_rng_state = get_rng_state(device=device)
+
+  try:
+    yield
+  finally:
+    set_rng_state(xla_rng_state, device=device)
+
+
 def get_memory_info(device):
   """Retrieves the device memory information.
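
For reference, a minimal sketch of how the new xm.fork_rng context manager behaves once the hunk above is applied. The seed value and tensor shape are arbitrary, chosen only for illustration; xm.get_rng_state, xm.set_rng_state, and xm.xla_device are existing torch_xla APIs referenced by the patch.

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    before = xm.get_rng_state()      # XLA seed before entering the fork
    with xm.fork_rng():
      xm.set_rng_state(1234)         # perturb the XLA RNG inside the fork
      torch.rand(4, device=device)   # a random op traced with the forked seed
    after = xm.get_rng_state()       # restored on exit by fork_rng's finally block
    assert before == after

Passing enabled=False makes fork_rng a no-op, mirroring torch.random.fork_rng.
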
diff --git a/torch_xla/utils/checkpoint.py b/torch_xla/utils/checkpoint.py
index 16cc1ffdfb2..d1ad3f70e0e 100644
--- a/torch_xla/utils/checkpoint.py
+++ b/torch_xla/utils/checkpoint.py
@@ -88,6 +88,7 @@ def forward(ctx, run_function, preserve_rng_state, *args):
         "cache_enabled": torch.is_autocast_cache_enabled()
     }
     if preserve_rng_state:
+      ctx.fwd_xla_state = xm.get_rng_state()
       ctx.fwd_cpu_state = torch.get_rng_state()
       # Don't eagerly initialize the cuda context by accident.
       # (If the user intends that the context is initialized later, within their
@@ -143,17 +144,21 @@ def backward(ctx, *args):
       rng_devices = ctx.fwd_gpu_devices
     xm.optimization_barrier_(
         CheckpointFunction._extract_tensors_from_list(inputs + list(args)))
+    # torch.random.fork_rng will handle the cpu and gpu seed
+    # xm.fork_rng will handle the xla device seed
     with torch.random.fork_rng(
         devices=rng_devices, enabled=ctx.preserve_rng_state):
-      if ctx.preserve_rng_state:
-        torch.set_rng_state(ctx.fwd_cpu_state)
-        if ctx.had_cuda_in_fwd:
-          set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
-      detached_inputs = detach_variable(tuple(inputs))
-      with torch.enable_grad(), \
-          torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
-          torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
-        outputs = ctx.run_function(*detached_inputs)
+      with xm.fork_rng():
+        if ctx.preserve_rng_state:
+          xm.set_rng_state(ctx.fwd_xla_state)
+          torch.set_rng_state(ctx.fwd_cpu_state)
+          if ctx.had_cuda_in_fwd:
+            set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
+        detached_inputs = detach_variable(tuple(inputs))
+        with torch.enable_grad(), \
+            torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
+            torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
+          outputs = ctx.run_function(*detached_inputs)
 
     if isinstance(outputs, torch.Tensor):
       outputs = (outputs,)
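
Taken together, the change means an activation-checkpointed module containing dropout now sees the same mask during the backward recompute as it did in forward: forward stashes the XLA seed in ctx.fwd_xla_state, and backward replays it inside xm.fork_rng() alongside the existing CPU/CUDA RNG handling. A rough usage sketch of the fixed behavior (the model, shapes, and dropout probability below are made up for illustration; checkpoint_module is existing torch_xla API, not part of this patch):

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.distributed.fsdp import checkpoint_module

    device = xm.xla_device()
    layer = torch.nn.Sequential(
        torch.nn.Linear(16, 16), torch.nn.Dropout(p=0.5)).to(device)
    layer = checkpoint_module(layer)  # activations are recomputed in backward

    x = torch.randn(4, 16, device=device, requires_grad=True)
    loss = layer(x).sum()
    loss.backward()                   # recompute runs under the restored XLA RNG state
    xm.mark_step()

The new TestActivationCheckpoint.test_dropout above verifies this property by comparing the dropout output captured during forward against the one captured during the backward recompute.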