Fix preserve_rng_state for activation checkpointing #4690
Conversation
Thanks! Mostly LGTM. Can you add a test case, maybe to https://github.com/pytorch/xla/blob/master/test/test_operations.py? You can compare the result between the xla device and the cpu device; this way we won't regress this.
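For illustration, here is a minimal sketch of the kind of cross-device comparison being suggested (a hypothetical check, not the test that was eventually added): run the same seeded computation on the CPU and on the XLA device and compare the results.

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

# Build the same seeded model twice so both copies start from identical weights.
torch.manual_seed(0)
cpu_model = nn.Linear(4, 4)
torch.manual_seed(0)
xla_model = nn.Linear(4, 4).to(xm.xla_device())

inp = torch.ones(2, 4)
cpu_out = cpu_model(inp)
xla_out = xla_model(inp.to(xm.xla_device()))

# Compare the XLA result (moved back to CPU) against the CPU reference.
assert torch.allclose(cpu_out, xla_out.cpu(), atol=1e-5)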
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
    outputs = ctx.run_function(*detached_inputs)
with xm.fork_rng():
Any reason not to pass the rng_devices and ctx.preserve_rng_state?
It looks like the upstream code doesn't reset the state. @YangFei1990 Do you know why?
I guess the upstream seed is handled by torch.random.fork_rng? Though I am not sure why it doesn't work with pytorch/xla...
Yeah, the upstream seed is handled by torch.random.fork_rng. It will fork the torch seed, but somehow it won't set XLA's RNG. Is this seed, torch_xla._XLAC._xla_get_rng_seed(str(device)), independent of the torch seed? How does torch XLA handle RNGs in general?
I did not change the previous behavior, i.e. the upstream seed is still maintained as it was (check the code below). I simply added preserving another RNG state.
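To make the "preserve another RNG state" idea concrete, here is a minimal sketch of saving and restoring the XLA seed around a block, in addition to what torch.random.fork_rng already preserves. The getter is the one quoted in this thread; the matching setter (_xla_set_rng_seed) is an assumption here, and this is an illustration rather than the exact code in this PR.

import contextlib

import torch
import torch_xla
import torch_xla.core.xla_model as xm


@contextlib.contextmanager
def preserve_xla_rng(device=None):
  # Illustrative only: snapshot the XLA RNG seed and restore it on exit.
  device = str(device or xm.xla_device())
  saved = torch_xla._XLAC._xla_get_rng_seed(device)  # getter quoted above
  try:
    yield
  finally:
    torch_xla._XLAC._xla_set_rng_seed(saved, device)  # assumed setter counterpart


# Restores both the torch RNG state and the XLA seed after the block,
# e.g. around the recomputed forward in checkpointing.
with torch.random.fork_rng(), preserve_xla_rng():
  torch.rand(2, 2, device=xm.xla_device())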
output = torch.sum(output)
output.backward()
xm.mark_step()
same_output = torch.allclose(model.to_save[0], model.to_save[1])
what's this to_save?
to_save is the container that holds the output tensor. With activation checkpointing the FWD will run twice, so this container can capture both tensors. Check line 2352.
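As a hypothetical illustration of that pattern (the real test lives in test_operations.py; the names here are made up): a module whose forward appends its dropout output to to_save ends up with two entries after a checkpointed forward/backward, one from the forward pass and one from the recomputation.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyModel(nn.Module):
  """Illustrative only: records the dropout output of every forward call."""

  def __init__(self):
    super().__init__()
    self.dropout = nn.Dropout(p=0.5)
    self.to_save = []

  def forward(self, x):
    out = self.dropout(x)
    self.to_save.append(out.detach())
    return out


model = TinyModel()
x = torch.ones(8, requires_grad=True)
checkpoint(model, x, use_reentrant=True, preserve_rng_state=True).sum().backward()
print(len(model.to_save))  # 2: one entry from forward, one from the recompute in backward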
test/test_operations.py
same_output = torch.allclose(model.to_save[0], model.to_save[1])
if not same_output:
  print(f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")
self.assertTrue(same_output)
I think you can do something similar to
self.assertTrue(same_output, f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")
Awesome, didn't know I could do that. Updating.
Mostly LGTM. Please address the comments.
I will take care of the backport.
In the activation checkpointing implementation we have the preserve_rng_state option. If it is set to True, activation checkpointing should use the same RNG state for the two forward runs in a single step. Consider the following test script with activation checkpointing and a dropout op in the model: if everything works right, same_output should be True. We observed that without XLA it works correctly, but with XLA it is wrong.

This PR fixes the issue by also saving/loading XLA's RNG state in the activation checkpointing implementation. After the fix the output matches between the two forward runs.
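The test script referenced above is not reproduced in this excerpt; below is a hedged sketch of what such a check could look like on an XLA device, assuming the XLA checkpoint wrapper touched by this PR is importable as torch_xla.utils.checkpoint.checkpoint. The module, names, and shapes are illustrative, not the author's actual script.

import torch
import torch.nn as nn

import torch_xla.core.xla_model as xm
from torch_xla.utils.checkpoint import checkpoint  # XLA checkpoint wrapper (assumed path)


class DropoutModel(nn.Module):

  def __init__(self):
    super().__init__()
    self.fc = nn.Linear(16, 16)
    self.dropout = nn.Dropout(p=0.5)
    self.to_save = []  # captures the dropout output of each forward run

  def forward(self, x):
    out = self.dropout(self.fc(x))
    self.to_save.append(out)
    return out


device = xm.xla_device()
model = DropoutModel().to(device)
inp = torch.randn(4, 16, device=device, requires_grad=True)

output = checkpoint(model, inp, preserve_rng_state=True)
output = torch.sum(output)
output.backward()
xm.mark_step()

# With the fix, the dropout mask (and hence the activation) matches
# between the forward pass and the backward-time recompute.
same_output = torch.allclose(model.to_save[0], model.to_save[1])
assert same_output, f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}"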