Fix preserve_rng_state for activation checkpointing #4690

Merged · 6 commits · Mar 20, 2024
36 changes: 36 additions & 0 deletions test/test_operations.py
@@ -32,6 +32,7 @@
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
import torch_xla.distributed.data_parallel as dp
+from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
import torch_xla.debug.metrics as met
import torch_xla.debug.model_comparator as mc
@@ -2334,6 +2335,41 @@ def test_aten_move_scalar_cuda_to_xla(self):
    self._test_move_tensor_cuda_to_xla(torch.tensor(42))


+class SimpleModelWithDropout(torch.nn.Module):
+
+  def __init__(self):
+    super().__init__()
+    self.x = torch.nn.Linear(128, 128)
+    self.dropout = torch.nn.Dropout(p=0.1)
+    self.to_save = []
+
+  def save_output(self, output):
+    self.to_save.append(output.detach().cpu())
+
+  def forward(self, inp):
+    x = self.x(inp)
+    output = self.dropout(x)
+    xm.add_step_closure(self.save_output, args=(output,), run_async=False)
+    return output


+class TestActivationCheckpoint(test_utils.XlaTestCase):
+
+  def test_dropout(self):
+    device = xm.xla_device()
+    model = SimpleModelWithDropout().to(device)
+    model = checkpoint_module(model)
+    _input = torch.randn(128, 128, requires_grad=True)
+    _input = _input.to(device)
+    output = model(_input)
+    output = torch.sum(output)
+    output.backward()
+    xm.mark_step()
+    same_output = torch.allclose(model.to_save[0], model.to_save[1])
Collaborator: what's this to_save?

Contributor (author): to_save is the container that holds the output tensor. With activation checkpointing, the forward pass runs twice, so this container captures both tensors. Check line 2352 (the assertion below), and see the sketch after this diff.

+    self.assertTrue(same_output,
+                    f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")


if __name__ == '__main__':
  torch.set_default_dtype(torch.float32)
  torch.manual_seed(42)
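A minimal sketch of the behavior the author describes above (illustrative code, not part of the PR): under activation checkpointing the wrapped forward runs twice, once in the forward pass and once recomputed during backward, so any side effect in forward fires twice. The Counting module is a hypothetical stand-in, and the sketch assumes checkpoint_module keeps the wrapped module's attributes reachable, as the test's model.to_save access relies on.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import checkpoint_module


class Counting(torch.nn.Module):
  """Hypothetical module that counts how often its forward body runs."""

  def __init__(self):
    super().__init__()
    self.lin = torch.nn.Linear(8, 8)
    self.calls = 0

  def forward(self, x):
    self.calls += 1  # fires on the original forward and on the recompute
    return self.lin(x)


device = xm.xla_device()
model = checkpoint_module(Counting().to(device))
inp = torch.randn(2, 8, requires_grad=True).to(device)
model(inp).sum().backward()  # backward triggers the second forward pass
xm.mark_step()
print(model.calls)  # expected: 2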
23 changes: 23 additions & 0 deletions torch_xla/core/xla_model.py
@@ -1,3 +1,4 @@
+import contextlib
import io
import itertools
import logging
@@ -1263,6 +1264,28 @@ def get_rng_state(device=None):
  return torch_xla._XLAC._xla_get_rng_seed(str(device) if device else '')


+@contextlib.contextmanager
+def fork_rng(device=None, enabled=True):
+  """Forks the RNG, so that when you exit the context the RNG state is reset to what it was on entry.
+
+  Args:
+    device (string, optional): The device whose RNG state should be forked. If missing, the default device's RNG state is forked.
+    enabled (bool): if ``False``, the RNG is not forked. This is a convenience argument for easily disabling the context manager without having to delete it and unindent the Python code under it.
+  """
+  if not enabled:
+    yield
+    return
+
+  if device is None:
+    device = torch_xla._XLAC._xla_get_default_device()
+  xla_rng_state = get_rng_state(device=device)
+
+  try:
+    yield
+  finally:
+    set_rng_state(xla_rng_state, device=device)


def get_memory_info(device):
"""Retrieves the device memory information.

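A minimal usage sketch for the new fork_rng (illustrative, not from the PR). By construction of the context manager above, the seed read on entry is restored in the finally block, so randomness consumed inside the fork does not shift the stream observed after it:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

seed_before = xm.get_rng_state()
with xm.fork_rng():
  torch.rand(4, device=device)  # consume some randomness inside the fork
seed_after = xm.get_rng_state()

# fork_rng restored the state it captured on entry, so both reads agree.
assert seed_before == seed_after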
23 changes: 14 additions & 9 deletions torch_xla/utils/checkpoint.py
@@ -88,6 +88,7 @@ def forward(ctx, run_function, preserve_rng_state, *args):
"cache_enabled": torch.is_autocast_cache_enabled()
}
if preserve_rng_state:
ctx.fwd_xla_state = xm.get_rng_state()
ctx.fwd_cpu_state = torch.get_rng_state()
# Don't eagerly initialize the cuda context by accident.
# (If the user intends that the context is initialized later, within their
@@ -143,17 +144,21 @@ def backward(ctx, *args):
      rng_devices = ctx.fwd_gpu_devices
    xm.optimization_barrier_(
        CheckpointFunction._extract_tensors_from_list(inputs + list(args)))
+    # torch.random.fork_rng will handle the cpu and gpu seed
+    # xm.fork_rng will handle the xla device seed
    with torch.random.fork_rng(
        devices=rng_devices, enabled=ctx.preserve_rng_state):
-      if ctx.preserve_rng_state:
-        torch.set_rng_state(ctx.fwd_cpu_state)
-        if ctx.had_cuda_in_fwd:
-          set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
-      detached_inputs = detach_variable(tuple(inputs))
-      with torch.enable_grad(), \
-          torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
-          torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
-        outputs = ctx.run_function(*detached_inputs)
+      with xm.fork_rng():
Collaborator: any reason not to pass the rng_devices and ctx.preserve_rng_state?

Collaborator: It looks like the upstream code doesn't reset the state. @YangFei1990 Do you know why?

Collaborator: I guess the upstream seed is handled by torch.random.fork_rng? Though I am not sure why it doesn't work with pytorch/xla...

Contributor (author): Yeah, the upstream seed is handled by torch.random.fork_rng. It will fork the torch seed, but somehow it won't set XLA's RNG. This seed, torch_xla._XLAC._xla_get_rng_seed(str(device)), is it independent of the torch seed? How does torch XLA handle RNGs in general?

Contributor (author): I did not change the previous behavior, i.e. the upstream seed is still maintained as it was (check the code below). I simply added another preserved RNG state; see also the sketch after this diff.

+        if ctx.preserve_rng_state:
+          xm.set_rng_state(ctx.fwd_xla_state)
+          torch.set_rng_state(ctx.fwd_cpu_state)
+          if ctx.had_cuda_in_fwd:
+            set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
+        detached_inputs = detach_variable(tuple(inputs))
+        with torch.enable_grad(), \
+            torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
+            torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
+          outputs = ctx.run_function(*detached_inputs)

    if isinstance(outputs, torch.Tensor):
      outputs = (outputs,)
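Tying the thread and the diff together, a condensed paraphrase (not verbatim PR code) of what the recompute path now does: the outer torch.random.fork_rng scopes the CPU and CUDA generators exactly as upstream does, the inner xm.fork_rng scopes the XLA device seed that torch.random.fork_rng never touches, and the states captured in forward are re-installed so ops like dropout replay the same mask during the recompute.

import torch
import torch_xla.core.xla_model as xm


def recompute_with_fwd_rng(run_function, detached_inputs, fwd_cpu_state,
                           fwd_xla_state, rng_devices=()):
  """Condensed paraphrase of the backward-side recompute in checkpoint.py."""
  # Outer scope: snapshot/restore the CPU (and listed CUDA) generators.
  with torch.random.fork_rng(devices=rng_devices, enabled=True):
    # Inner scope: snapshot/restore the XLA seed, which the outer scope
    # does not know about.
    with xm.fork_rng():
      # Replay the RNG states captured during the original forward.
      xm.set_rng_state(fwd_xla_state)
      torch.set_rng_state(fwd_cpu_state)
      with torch.enable_grad():
        return run_function(*detached_inputs)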