Skip to content

Commit

Permalink
Fix preserve_rng_state for activation checkpointing (#4690)
Browse files Browse the repository at this point in the history
  • Loading branch information
YangFei1990 authored and JackCaoG committed Mar 20, 2024
1 parent c923e8f commit 49efb2b
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 9 deletions.
36 changes: 36 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
import torch_xla.distributed.data_parallel as dp
from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
import torch_xla.debug.metrics as met
import torch_xla.debug.model_comparator as mc
Expand Down Expand Up @@ -2334,6 +2335,41 @@ def test_aten_move_scalar_cuda_to_xla(self):
self._test_move_tensor_cuda_to_xla(torch.tensor(42))


class SimpleModelWithDropout(torch.nn.Module):

def __init__(self):
super().__init__()
self.x = torch.nn.Linear(128, 128)
self.dropout = torch.nn.Dropout(p=0.1)
self.to_save = []

def save_output(self, output):
self.to_save.append(output.detach().cpu())

def forward(self, inp):
x = self.x(inp)
output = self.dropout(x)
xm.add_step_closure(self.save_output, args=(output,), run_async=False)
return output


class TestActivationCheckpoint(test_utils.XlaTestCase):

def test_dropout(self):
device = xm.xla_device()
model = SimpleModelWithDropout().to(device)
model = checkpoint_module(model)
_input = torch.randn(128, 128, requires_grad=True)
_input = _input.to(device)
output = model(_input)
output = torch.sum(output)
output.backward()
xm.mark_step()
same_output = torch.allclose(model.to_save[0], model.to_save[1])
self.assertTrue(same_output,
f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")


if __name__ == '__main__':
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
Expand Down
23 changes: 23 additions & 0 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import io
import itertools
import logging
Expand Down Expand Up @@ -1263,6 +1264,28 @@ def get_rng_state(device=None):
return torch_xla._XLAC._xla_get_rng_seed(str(device) if device else '')


@contextlib.contextmanager
def fork_rng(device=None, enabled=True):
"""
Forks the RNG, so that when you return, the RNG is reset to the state that it was previously in.
Args:
device (string, optional): The device where the RNG state needs to be set. If missing the default device seed will be set.
enabled (bool): if ``False``, the RNG is not forked. This is a convenience argument for easily disabling the context manager without having to delete it and unindent your Python code under it.
"""
if not enabled:
yield
return

if device is None:
device = torch_xla._XLAC._xla_get_default_device()
xla_rng_state = get_rng_state(device=device)

try:
yield
finally:
set_rng_state(xla_rng_state, device=device)


def get_memory_info(device):
"""Retrieves the device memory information.
Expand Down
23 changes: 14 additions & 9 deletions torch_xla/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def forward(ctx, run_function, preserve_rng_state, *args):
"cache_enabled": torch.is_autocast_cache_enabled()
}
if preserve_rng_state:
ctx.fwd_xla_state = xm.get_rng_state()
ctx.fwd_cpu_state = torch.get_rng_state()
# Don't eagerly initialize the cuda context by accident.
# (If the user intends that the context is initialized later, within their
Expand Down Expand Up @@ -143,17 +144,21 @@ def backward(ctx, *args):
rng_devices = ctx.fwd_gpu_devices
xm.optimization_barrier_(
CheckpointFunction._extract_tensors_from_list(inputs + list(args)))
# torch.random.fork_rng will handle the cpu and gpu seed
# xm.fork_rng will handle the xla device seed
with torch.random.fork_rng(
devices=rng_devices, enabled=ctx.preserve_rng_state):
if ctx.preserve_rng_state:
torch.set_rng_state(ctx.fwd_cpu_state)
if ctx.had_cuda_in_fwd:
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
detached_inputs = detach_variable(tuple(inputs))
with torch.enable_grad(), \
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
outputs = ctx.run_function(*detached_inputs)
with xm.fork_rng():
if ctx.preserve_rng_state:
xm.set_rng_state(ctx.fwd_xla_state)
torch.set_rng_state(ctx.fwd_cpu_state)
if ctx.had_cuda_in_fwd:
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
detached_inputs = detach_variable(tuple(inputs))
with torch.enable_grad(), \
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
outputs = ctx.run_function(*detached_inputs)

if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
Expand Down

0 comments on commit 49efb2b

Please sign in to comment.