DDP native mixed precision #92882
Conversation
Per title.

Differential Revision: [D42515803](https://our.internmc.facebook.com/intern/diff/D42515803/)
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/92882. As of commit 6f296ed there were 2 failures, both broken-trunk jobs that also failed on the merge base 44d8e6c; rebase onto the `viable/strict` branch to avoid them. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Implements native mixed precision support for DDP in a similar fashion to how it is enabled for FSDP. The implementation works as follows (a minimal standalone sketch of the parameter-swap pattern appears right after this description):

1. In DDP init, we save `_mp_param` and `_fp_param` variables to manage mixed precision parameter usage. In particular, `_mp_param` represents the parameter in the reduced precision, while `_fp_param` represents the parameter in regular precision. During forward/backward, we swap back and forth as needed.
2. The root module gets a root pre-forward hook that kicks off copies to the reduced precision for all submodules. An event is recorded for each submodule to allow for waiting, as we run these copies asynchronously.
3. Each module gets a pre-forward hook that waits on its corresponding event. Note that modules might be reused during training; in that case the wait is only done on first use. After this wait, the module's parameters are in reduced precision.
4. In the pre-forward hook, we register a backward hook on the lower precision parameters in order to run the reduced precision allreduce plus the parameter upcast. We can't rely on the Reducer's constructor setting up these hooks, because the gradient is accumulated on the low precision parameter, so we need to register them ourselves.
5. In the backward pass, when the hook runs, we first run allreduce + divide in the reduced precision. Next, we upcast parameters and gradients back to fp32 asynchronously. We also queue a callback at the end of backward to wait on these upcasts, so that the upcast is complete before `optim.step()` runs.
6. Parameters that don't require grad are also cast, since they may be used in computation; they are upcast back in the final autograd callback.
7. DDP-ignored parameters are not touched.

Follow-ups:

1. Unify comm hooks and make this work with apply-optimizer-in-backward.
2. Implement `keep_low_precision_grads`.
3. Allow BN, LN, or custom units to run in reduced precision.
4. Support `cast_forward_inputs`.
5. Unify certain APIs/helpers with FSDP where possible, such as `_cast_forward_inputs`.
6. Integrate this with the `replicate()` API.
7. The order in which we kick off copies and wait for them is set by the iteration order of `module.modules()`, but this might not be how the modules are used in the actual training. In the worst case, the last module in `module.modules()` could be used first, which would result in waiting for all copies unnecessarily. For static graphs, we should record the module execution order and copy/wait in that order.
8. Entirely unused modules probably don't need to be cast.

Differential Revision: [D42515803](https://our.internmc.facebook.com/intern/diff/D42515803/)
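A minimal single-process sketch of the `_mp_param`/`_fp_param` swap pattern described above. It assumes a CUDA device, omits DDP, events, buckets, and ignored parameters, and the helper names (`attach_mp_params`, `cast_params_to_low_precision`, `upcast_params`) are illustrative rather than the actual DDP internals.

```python
import torch
import torch.nn as nn

def attach_mp_params(module: nn.Module, param_dtype: torch.dtype = torch.float16) -> None:
    """Save a low-precision shadow (_mp_param) next to each full-precision param (_fp_param)."""
    for param in module.parameters():
        param._fp_param = param.data                      # regular-precision storage
        param._mp_param = torch.empty_like(param, dtype=param_dtype)

def cast_params_to_low_precision(module: nn.Module) -> None:
    """Pre-forward: copy fp32 -> low precision and point the param at the low-precision tensor."""
    with torch.no_grad():
        for param in module.parameters():
            param._mp_param.copy_(param)                  # copy_ implicitly downcasts
            param.data = param._mp_param

def upcast_params(module: nn.Module) -> None:
    """End of backward: restore fp32 params and upcast any accumulated grads."""
    with torch.no_grad():
        for param in module.parameters():
            param.data = param._fp_param
            if param.grad is not None:
                param.grad.data = param.grad.to(param.dtype)

# Usage sketch (single process, no DDP/allreduce, CUDA assumed):
model = nn.Linear(8, 8).cuda()
attach_mp_params(model)
cast_params_to_low_precision(model)
out = model(torch.randn(4, 8, device="cuda", dtype=torch.float16)).sum()
out.backward()
upcast_params(model)   # params and grads are fp32 again before optimizer.step()
```

In the actual change, the downcast happens in the pre-forward hooks and the upcast happens in the backward hooks plus the final autograd callback, as described above.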
Thanks for putting this together! Left some comments; also, could you use `.data` as little as possible?
torch/nn/parallel/distributed.py (outdated)
```python
keep_batchnorm_fp32: bool = True
keep_layernorm_fp32: bool = True
```
Seems that layernorm is fine to run in low precision; maybe default them to False?
Also, I did not see any code to handle `keep_batchnorm_fp32`/`keep_layernorm_fp32` right now.
Will add this as follow up work.
torch/nn/parallel/distributed.py (outdated)
```python
# Do not cast DDP ignored parameters.
if hasattr(param, '_ddp_ignored') and param._ddp_ignored:
    continue
_alloc_storage(param._mp_param, param.data.size())
```
Nit: `param.size()`; let's try to use `.data` as little as possible.
```python
_alloc_storage(param._mp_param, param.data.size())
# copy() implicitly casts to low precision
with torch.no_grad():
    param._mp_param.copy_(param.data)
```
same, could we remove .data?
nit: remove .data
```python
hook = grad_acc.register_hook(
    functools.partial(self._fire_reducer_autograd_hook, p._idx)
)
```
Hooks registered in the ctor are not deregistered, right? Even if they are never fired.
They are not deregistered. If we'd like, we could have an API to de-register them if mixed precision is enabled.
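For context on the handle returned by `register_hook`, here is a rough standalone sketch of the AccumulateGrad hook pattern from the snippet above; `fire_reducer_hook` and the parameter index are placeholders, not the actual Reducer plumbing.

```python
import functools
import torch

def fire_reducer_hook(param_index, *unused):
    # Placeholder for whatever the reducer would do when this grad is ready.
    print(f"grad ready for param {param_index}")

param = torch.nn.Parameter(torch.randn(4))

# The AccumulateGrad node is reachable via a view's grad_fn; expand_as is the
# common trick to obtain it without modifying the parameter.
grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]

# register_hook returns a removable handle, so hooks *can* be deregistered later
# (e.g., if mixed precision registers its own hooks instead).
handle = grad_acc.register_hook(functools.partial(fire_reducer_hook, 0))

(param * 2).sum().backward()   # fires the hook when param's gradient is processed
handle.remove()                # deregisters the hook
```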
torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py (two resolved review threads, outdated)
```python
p._ddp_mp_hook_state[1].remove()
delattr(p, '_ddp_mp_hook_state')
if not p.requires_grad and not hasattr(p, '_ddp_ignored'):
    p.data = p._fp_param
```
Hasn't the upcast already been done in the code above?
This is for params that don't require grad. The code above doesn't run for them since their autograd hooks never fire, but we still need to upcast them here, as they were downcast before the forward (we downcast them because they may participate in computation).
(also commented on L60)
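A hypothetical sketch of that final-callback upcast for parameters that never fire an autograd hook; the attribute names follow this PR, but the helpers themselves are illustrative, and `queue_callback` is an internal autograd API that must be invoked while a backward pass is running.

```python
import torch
from torch.autograd import Variable

def _upcast_no_grad_params(module: torch.nn.Module) -> None:
    # Restore fp32 storage for params whose backward hooks never fired.
    with torch.no_grad():
        for p in module.parameters():
            if not p.requires_grad and hasattr(p, "_fp_param") and not hasattr(p, "_ddp_ignored"):
                p.data = p._fp_param

def _queue_final_upcast(module: torch.nn.Module) -> None:
    # Call this from a hook that runs during backward; the callback then fires
    # once the autograd engine finishes the current backward pass.
    Variable._execution_engine.queue_callback(lambda: _upcast_no_grad_params(module))
```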
```python
with torch.cuda.stream(stream):
    fut.wait()
    bucket.buffer().div_(process_group.size())
    ret_fut.set_result(bucket.buffer())
```
The allreduced bucket will be used to update grads in `finalize_bucket_dense()` later on, so the upcast code below doesn't seem to take effect? Should the upcast possibly be done after `finalize_bucket_dense()`?
hmm this is a good point, let me look into it
The upcast code actually runs before `finalize_bucket_dense()`:

- This code is kicked off in the comm hook callback and runs on a separate stream. That stream is waited on by the `queue_callback` we register with the autograd engine below.
- The C++ `finalize_backward` also happens as an autograd engine final callback, but it runs after the Python callback, since we insert the Python callback first; that is why the upcast code here works as expected.
- It is not great to rely on the autograd engine's callback running order. We can look into consolidating and always inserting only a single callback for DDP.

(A simplified sketch of the reduced-precision allreduce hook follows.)
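The sketch below follows the standard DDP comm-hook contract (`hook(state, bucket) -> Future[Tensor]`, with `bucket.buffer()` as the flat gradient tensor) but deliberately omits the separate upcast stream and the end-of-backward callback discussed above; `state` is assumed to be a process group or `None`, and `low_precision_allreduce_hook` is an illustrative name.

```python
import torch
import torch.distributed as dist

def low_precision_allreduce_hook(state, bucket):
    group = state if state is not None else dist.group.WORLD
    world_size = dist.get_world_size(group)

    # In the mixed precision setup the bucket buffer already holds low-precision
    # gradients, so allreduce it directly and average afterwards.
    fut = dist.all_reduce(bucket.buffer(), group=group, async_op=True).get_future()

    def _average(fut):
        # fut.value() is a one-element list holding the allreduced buffer.
        return fut.value()[0].div_(world_size)

    return fut.then(_average)

# Usage (after init_process_group and wrapping the model in DDP):
# ddp_model.register_comm_hook(state=None, hook=low_precision_allreduce_hook)
```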
Overall looks good; since the code path is gated, it should be safe to land. Left some comments about removing `.data` usage, which can be done in a follow-up PR.

Another important thing: I just realized that buckets and grads are almost always detached from each other because of the downcast and upcast in every iteration's forward. We should make grads point to the buckets at the end of forward to avoid copying between buckets and grads; otherwise this may cause a performance regression for some workloads. This can be added in a follow-up PR, though.
```python
# free storage for mp param as it will be allocated again in next
# forward pass.
_free_storage(p._mp_param)
p.grad.data = p.grad.to(p.data.dtype)
```
nit: p.grad = p.grad.to(p.dtype), avoid using .data
Discussed offline to leave this as is; otherwise, when `set_to_none=False`, the gradients could potentially become unlinked from the buckets here.
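A tiny illustration of the aliasing point, assuming only generic PyTorch behavior (not the DDP bucket internals): rebinding `.grad` creates a new tensor object, while assigning to `.grad.data` mutates the existing one, so cached references such as bucket views stay linked only in the second case. The same reasoning applies when the replacement tensor has a different dtype, which is the upcast situation here.

```python
import torch

p = torch.nn.Parameter(torch.randn(3))
p.grad = torch.zeros(3)
cached = p.grad                     # stand-in for a reference held elsewhere

p.grad = p.grad + 1                 # rebinds `.grad` to a brand-new tensor
print(p.grad is cached)             # False: `cached` no longer tracks p.grad

p.grad = cached                     # restore the original binding
p.grad.data = torch.full((3,), 7.)  # swap the data of the *same* tensor object
print(p.grad is cached, cached[0])  # True, and `cached` sees the new values
```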
```python
if hasattr(buf, '_ddp_ignored') and buf._ddp_ignored:
    continue

buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype)
```
nit: could we do 'buf = buf.to()'?
This doesn't work, since it would just point the local `buf` variable at the fp16 tensor, while the module's `.buffers()` would not be updated. FSDP does the same thing.

If we'd like to avoid `.data` usage here, we could de-register the float32 buffer and register an fp16 buffer instead.
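A sketch of that re-registration alternative, using only public `nn.Module` buffer APIs; `cast_buffers` is an illustrative helper, and note that re-registering does not preserve a non-persistent flag.

```python
import torch
import torch.nn as nn

def cast_buffers(module: nn.Module, buffer_dtype: torch.dtype = torch.float16) -> None:
    for submodule in module.modules():
        for name, buf in list(submodule.named_buffers(recurse=False)):
            if buf is None or not torch.is_floating_point(buf):
                continue
            if getattr(buf, "_ddp_ignored", False):
                continue
            # Re-registering under the same name replaces the entry in the
            # module's buffer dict, so .buffers()/state_dict see the new tensor.
            submodule.register_buffer(name, buf.to(dtype=buffer_dtype))

bn = nn.BatchNorm1d(4)
cast_buffers(bn)
print(bn.running_mean.dtype)  # torch.float16
```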
torch/nn/parallel/distributed.py (outdated)
```python
if not hasattr(param, '_mp_param'):
    param._mp_param = torch.zeros_like(
        param.data,
```
nit: remove .data here
```python
# back to at the end of forward / backward.
param._fp_param = param.data

def _cast_forward_inputs(
```
nit: move to util helper
```python
_alloc_storage(param._mp_param, param.data.size())
# copy() implicitly casts to low precision
with torch.no_grad():
    param._mp_param.copy_(param.data)
```
nit: remove .data
torch/nn/parallel/distributed.py (outdated)
```python
# is saved and .grad field is set to None, bypassing
# this issue.
if param.grad is not None:
    param.grad.data = param.grad.to(self.mixed_precision.param_dtype)
```
nit: remove .data?
Similar to the above: this would de-link the pointers for `gradient_as_bucket_view` (when `set_to_none=False`), so we decided to keep it as is for now.
@pytorchbot merge -f "CI unrelated"

Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch/pytorch#92882. Approved by: https://github.com/zhaojuanmao