-
Notifications
You must be signed in to change notification settings - Fork 486
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix preserve_rng_state for activation checkpointing #4690
Changes from all commits
29becf3
05ad6e4
392fe6c
d3fb5a6
355efb2
78c08dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -88,6 +88,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): | |
"cache_enabled": torch.is_autocast_cache_enabled() | ||
} | ||
if preserve_rng_state: | ||
ctx.fwd_xla_state = xm.get_rng_state() | ||
ctx.fwd_cpu_state = torch.get_rng_state() | ||
# Don't eagerly initialize the cuda context by accident. | ||
# (If the user intends that the context is initialized later, within their | ||
|
@@ -143,17 +144,21 @@ def backward(ctx, *args): | |
rng_devices = ctx.fwd_gpu_devices | ||
xm.optimization_barrier_( | ||
CheckpointFunction._extract_tensors_from_list(inputs + list(args))) | ||
# torch.random.fork_rng will handle the cpu and gpu seed | ||
# xm.fork_rng will handle the xla device seed | ||
with torch.random.fork_rng( | ||
devices=rng_devices, enabled=ctx.preserve_rng_state): | ||
if ctx.preserve_rng_state: | ||
torch.set_rng_state(ctx.fwd_cpu_state) | ||
if ctx.had_cuda_in_fwd: | ||
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) | ||
detached_inputs = detach_variable(tuple(inputs)) | ||
with torch.enable_grad(), \ | ||
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \ | ||
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): | ||
outputs = ctx.run_function(*detached_inputs) | ||
with xm.fork_rng(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. any reason not to pass the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like the upstream code doesn't reset the state. @YangFei1990 Do you know why? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess upstream seed is handled by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah upstream seed is handled by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did not change the previous behavior, i.e. upstream seed will still be maintained as it was (check code below). I simply add another preserve RNG states. |
||
if ctx.preserve_rng_state: | ||
xm.set_rng_state(ctx.fwd_xla_state) | ||
torch.set_rng_state(ctx.fwd_cpu_state) | ||
if ctx.had_cuda_in_fwd: | ||
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) | ||
detached_inputs = detach_variable(tuple(inputs)) | ||
with torch.enable_grad(), \ | ||
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \ | ||
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): | ||
outputs = ctx.run_function(*detached_inputs) | ||
|
||
if isinstance(outputs, torch.Tensor): | ||
outputs = (outputs,) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's this to_save?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to_save is the container to hold the output tensor. With activation checkpointing the FWD will run twice, this container can capture both tensors. Check line 2352.