
DDP native mixed precision #92882

Closed · rohan-varma wants to merge 20 commits

Conversation

@rohan-varma (Member) commented on Jan 24, 2023

Stack from ghstack (oldest at bottom):

Implements native mixed precision support for DDP, in a similar fashion to how it is enabled for FSDP. The implementation works as follows (see the sketch after this list for steps 1-3):

  1. In DDP init, we save `_mp_param` and `_fp_param` variables to manage mixed precision parameter usage. In particular, `_mp_param` represents the parameter in the reduced precision, while `_fp_param` represents the parameter in regular precision. During forward/backward, we swap back and forth as needed.
  2. The root module gets a root pre-forward hook that kicks off copies to the reduced precision for all submodules. An event is recorded for each submodule to allow for waiting, since we run these copies asynchronously.
  3. Each module gets a pre-forward hook that waits on its corresponding event. Note that modules might be reused during training; in this case the wait is only done on the first use. After this wait, the module's parameters are in reduced precision.
  4. In the pre-forward hook, we register a backward hook on the lower precision parameters in order to run the reduced precision allreduce + parameter upcast. We can't rely on the Reducer's constructor setting up these hooks, because the gradient is accumulated on the low precision param, so we need to register them ourselves.
  5. In the backward pass, when the hook runs, we first run allreduce + divide in the reduced precision. Next, we upcast parameters and gradients back to fp32 asynchronously. We also queue a callback at the end of backward to wait on these upcasts, so that the upcast is complete before optim.step() runs.
  6. Parameters that don't require grad are also cast, since they may be used in computation; they are upcast back in the final autograd callback.
  7. DDP ignored parameters are not touched.
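
For readers new to this flow, here is a minimal, hedged sketch of steps 1-3 using only public PyTorch APIs. The helper names (`_setup_mixed_precision_params`, `_root_copy_hook`, `_module_wait_hook`) are invented for illustration; the real logic lives inside DDP and also allocates/frees the low-precision storage each iteration, which is omitted here. Requires CUDA.

```python
import torch
import torch.nn as nn

def _setup_mixed_precision_params(module: nn.Module, param_dtype: torch.dtype) -> None:
    # Step 1: keep a low-precision shadow (_mp_param) next to the
    # full-precision storage (_fp_param) so forward/backward can swap them.
    for param in module.parameters():
        if getattr(param, "_ddp_ignored", False):
            continue  # step 7: DDP-ignored params are left untouched
        param._mp_param = torch.zeros_like(param, dtype=param_dtype)
        param._fp_param = param.data

_copy_stream = torch.cuda.Stream()

def _root_copy_hook(root: nn.Module, inputs):
    # Step 2: the root pre-forward hook kicks off async downcasts for every
    # submodule and records an event per submodule so each can wait later.
    _copy_stream.wait_stream(torch.cuda.current_stream())
    for submodule in root.modules():
        event = torch.cuda.Event()
        with torch.cuda.stream(_copy_stream):
            for param in submodule.parameters(recurse=False):
                if getattr(param, "_ddp_ignored", False):
                    continue
                with torch.no_grad():
                    param._mp_param.copy_(param)  # copy_ downcasts implicitly
                param.data = param._mp_param
        event.record(_copy_stream)
        submodule._mp_copy_event = event

def _module_wait_hook(module: nn.Module, inputs):
    # Step 3: each submodule waits only for its own copy; a reused module
    # only waits on its first use because the event is removed afterwards.
    event = getattr(module, "_mp_copy_event", None)
    if event is not None:
        torch.cuda.current_stream().wait_event(event)
        del module._mp_copy_event

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1)).cuda()
_setup_mixed_precision_params(model, torch.float16)
model.register_forward_pre_hook(_root_copy_hook)
for m in model.modules():
    m.register_forward_pre_hook(_module_wait_hook)
out = model(torch.randn(2, 8, device="cuda", dtype=torch.float16))
```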

Follow-ups:

  1. Unify comm hooks and make this work with apply-optimizer-in-backward.
  2. Implement keep_low_precision_grads.
  3. Allow BN, LN, or custom units to run in reduced precision.
  4. Support cast_forward_inputs.
  5. Unify certain APIs / helpers with FSDP where possible, such as for _cast_forward_inputs.
  6. Integrate this with the replicate() API.
  7. The order in which we kick off copies and wait for them is set by the iteration order of module.modules(), but this might not match how the modules are used in the actual training. In the worst case, the last module in module.modules() could be used first, which would result in waiting for all copies unnecessarily. For static graphs, we should record the module execution order and copy / wait in that order.
  8. Entirely unused modules probably don't need to be cast.

Differential Revision: [D42515803](https://our.internmc.facebook.com/intern/diff/D42515803/)
@pytorch-bot (bot) commented on Jan 24, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/92882

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 Failures

As of commit 6f296ed:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base 44d8e6c:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

rohan-varma added a commit that referenced this pull request on Jan 24, 2023 (ghstack-source-id: 178271673).
@rohan-varma changed the title from "DDP native mixed precision" to "[WIP] DDP native mixed precision" on Jan 24, 2023.
rohan-varma added a commit that referenced this pull request on Jan 30, 2023 (ghstack-source-id: 178775648).
@rohan-varma changed the title from "[WIP] DDP native mixed precision" to "DDP native mixed precision" on Feb 8, 2023.
rohan-varma added a commit that referenced this pull request on Feb 8, 2023 (ghstack-source-id: 179617838).

rohan-varma added a commit that referenced this pull request on Feb 8, 2023 (ghstack-source-id: 179693212).
@zhaojuanmao (Contributor) left a comment:

Thanks for putting this together! I left some comments; also, could you please use `.data` as little as possible?

Review thread on torch/nn/parallel/distributed.py (outdated; resolved).
Comment on lines 101 to 102:

    keep_batchnorm_fp32: bool = True
    keep_layernorm_fp32: bool = True

@zhaojuanmao (Contributor): It seems that layernorm is fine to run in low precision; maybe default these to False?

@zhaojuanmao (Contributor): Also, I did not see any code handling keep_batchnorm_fp32 / keep_layernorm_fp32 right now.

@rohan-varma (Member, Author): Will add this as follow-up work.
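
For context, FSDP configures mixed precision through a small dataclass (`torch.distributed.fsdp.MixedPrecision` with `param_dtype`, `reduce_dtype`, `buffer_dtype`), and the DDP config under review here appears analogous. The sketch below is an assumption about its shape, not the merged API; only the two fields quoted in the diff above are taken from this PR.

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class _MixedPrecision:
    param_dtype: Optional[torch.dtype] = None    # dtype params are cast to for forward/backward
    reduce_dtype: Optional[torch.dtype] = None   # dtype used for the gradient allreduce
    buffer_dtype: Optional[torch.dtype] = None   # dtype buffers (e.g. BN running stats) are cast to
    keep_batchnorm_fp32: bool = True             # quoted from the diff; handling is follow-up work
    keep_layernorm_fp32: bool = True             # reviewer suggests defaulting this to False

mp_config = _MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)
```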

Review thread:

    # Do not cast DDP ignored parameters.
    if hasattr(param, '_ddp_ignored') and param._ddp_ignored:
        continue
    _alloc_storage(param._mp_param, param.data.size())

@zhaojuanmao (Contributor): Nit: use param.size(); let's try to use .data as little as possible.

Review thread:

    _alloc_storage(param._mp_param, param.data.size())
    # copy() implicitly casts to low precision
    with torch.no_grad():
        param._mp_param.copy_(param.data)

@zhaojuanmao (Contributor): Same here, could we remove .data?

@zhaojuanmao (Contributor): Nit: remove .data.
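
To make the nit concrete, here is a small sketch of the `.data`-free variant on plain tensors (not the DDP internals themselves): shape can be read from the `Parameter` directly, and the cast-and-copy can run under `torch.no_grad()`.

```python
import torch

param = torch.nn.Parameter(torch.randn(4, 4))
mp_param = torch.zeros_like(param, dtype=torch.float16)

size = param.size()          # same shape information as param.data.size()
with torch.no_grad():
    mp_param.copy_(param)    # copy_ casts fp32 -> fp16 without autograd tracking
```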

Review thread on torch/nn/parallel/distributed.py (outdated; resolved).
Review thread on torch/nn/parallel/distributed.py (resolved).
Comment on lines +973 to +975:

    hook = grad_acc.register_hook(
        functools.partial(self._fire_reducer_autograd_hook, p._idx)
    )

@zhaojuanmao (Contributor): The hooks registered in the ctor are not deregistered, right? Even though they are not fired.

@rohan-varma (Member, Author): They are not deregistered. If we'd like, we could add an API to de-register them when mixed precision is enabled.
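
For readers unfamiliar with this pattern, below is a hedged sketch of registering a backward hook on a parameter's gradient accumulator and keeping the handle so it can be removed later. `_on_grad_ready` stands in for DDP's reducer autograd hook, and the `expand_as` trick is just a common way to reach the AccumulateGrad node from Python; this is illustrative, not the PR's exact code.

```python
import torch

def _register_accum_grad_hook(param: torch.nn.Parameter, param_index: int, on_ready):
    # expand_as gives a view whose grad_fn leads back to this param's
    # AccumulateGrad node (assumes param.requires_grad is True).
    tmp = param.expand_as(param)
    grad_acc = tmp.grad_fn.next_functions[0][0]

    def _hook(*unused):
        on_ready(param_index)

    handle = grad_acc.register_hook(_hook)
    # Keep both alive: grad_acc would otherwise be garbage collected, and the
    # handle is what allows de-registration (the point raised in this thread).
    param._ddp_mp_hook_state = (grad_acc, handle)
    return handle

# De-registration later, mirroring p._ddp_mp_hook_state[1].remove() below:
# grad_acc, handle = param._ddp_mp_hook_state
# handle.remove()
```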

Review thread:

    p._ddp_mp_hook_state[1].remove()
    delattr(p, '_ddp_mp_hook_state')
    if not p.requires_grad and not hasattr(p, '_ddp_ignored'):
        p.data = p._fp_param

@zhaojuanmao (Contributor): Hasn't the upcast already been done in the code above?

@rohan-varma (Member, Author): This is for params that don't require grad. The code above doesn't run for them since their autograd hooks do not fire, but we still need to upcast them, as they were downcast before the forward (we have to do that because they may participate in computation).

@rohan-varma (Member, Author): (Also commented on L60.)
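
The "final autograd callback" mentioned in the description (steps 5 and 6) can be sketched as below. `Variable._execution_engine.queue_callback` is a private API that DDP itself uses; in the PR the callback is queued from inside an autograd hook during the backward pass, and the helper name here is hypothetical.

```python
import torch
from torch.autograd import Variable

upcast_stream = torch.cuda.Stream()

def _queue_final_upcast(params):
    # Intended to be called from within a backward hook, so the callback runs
    # once the current backward pass finishes.
    def _final_callback():
        # Wait for the async fp16 -> fp32 upcasts issued on upcast_stream,
        # so optim.step() sees fp32 params/grads.
        torch.cuda.current_stream().wait_stream(upcast_stream)
        # Params that never fired an autograd hook (requires_grad=False) are
        # swapped back to their full-precision storage here.
        for p in params:
            if not p.requires_grad and not getattr(p, "_ddp_ignored", False):
                p.data = p._fp_param
    Variable._execution_engine.queue_callback(_final_callback)
```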

Review thread:

    with torch.cuda.stream(stream):
        fut.wait()
        bucket.buffer().div_(process_group.size())
        ret_fut.set_result(bucket.buffer())

@zhaojuanmao (Contributor, Feb 10, 2023): The allreduced bucket will be used to update grads in finalize_bucket_dense() later on, so the upcast code below seems not to take effect? The upcast should possibly be done after finalize_bucket_dense()?

@rohan-varma (Member, Author): Hmm, this is a good point, let me look into it.

@rohan-varma (Member, Author): The upcast code actually runs before finalize_bucket_dense:

  1. This code is kicked off in the comm hook callback and runs on a separate stream. That stream is waited on by the queue_callback we pass to the autograd engine below.
  2. The C++ finalize_backward also happens as an autograd engine final callback. However, it runs after the Python callback, since we insert the Python callback first, which is why the upcast code here works as expected.
  3. It is not great to rely on the autograd engine's callback running order. We can look into consolidating and always inserting only a single callback for DDP.
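
For context, a reduced-precision allreduce comm hook in the spirit of the snippet above can be written with the standard comm-hook interface. This sketch only covers the allreduce + divide part; the stream handling and the parameter/gradient upcast that the PR performs afterwards are omitted, and `_reduced_precision_hook` is an illustrative name, not the PR's hook.

```python
import torch
import torch.distributed as dist

def _reduced_precision_hook(process_group, bucket) -> torch.futures.Future:
    group = process_group if process_group is not None else dist.group.WORLD
    world_size = group.size()
    # The bucket buffer is already in the reduced precision here.
    fut = dist.all_reduce(bucket.buffer(), group=group, async_op=True).get_future()

    def _divide(fut):
        buf = fut.value()[0]
        buf.div_(world_size)  # average the gradients after the allreduce
        return buf

    return fut.then(_divide)

# Registration sketch (the PR wires its hook up internally):
# ddp_model.register_comm_hook(state=None, hook=_reduced_precision_hook)
```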

rohan-varma added a commit that referenced this pull request on Feb 27, 2023 (ghstack-source-id: 181345031).
rohan-varma added a commit that referenced this pull request on Mar 6, 2023 (ghstack-source-id: 182069766).
@zhaojuanmao (Contributor) left a comment:

Overall looks good; since the code path is gated, it should be safe to land. I left some comments about removing .data usage, which can be addressed in a follow-up PR.

Another important thing: I just realized that buckets and grads end up detached from each other because of the downcast and upcast in every iteration's forward. We should make grads point to the buckets at the end of the forward to avoid copying between buckets and grads; otherwise this may cause a performance regression for some workloads. This can also be added in a follow-up PR.
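
For readers unfamiliar with the aliasing being discussed: with `gradient_as_bucket_view=True`, DDP normally makes `param.grad` a view into the communication bucket so no extra grad/bucket copy is needed; re-materializing grads in a different dtype each iteration breaks that aliasing, which is the copy overhead flagged above. A minimal runnable setup (single-process gloo, shown only so the snippet executes):

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(8, 8)
ddp_model = DDP(model, gradient_as_bucket_view=True)

ddp_model(torch.randn(2, 8)).sum().backward()
# With the bucket views intact, param.grad shares storage with the reducer's bucket.
dist.destroy_process_group()
```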

Review thread on torch/csrc/distributed/c10d/reducer.cpp (outdated; resolved).
Review thread:

    # free storage for mp param as it will be allocated again in next
    # forward pass.
    _free_storage(p._mp_param)
    p.grad.data = p.grad.to(p.data.dtype)

@zhaojuanmao (Contributor): Nit: use p.grad = p.grad.to(p.dtype) to avoid using .data.

@rohan-varma (Member, Author): Discussed offline to leave this as is; otherwise, when set_to_none=False, the gradients could become unlinked from the buckets here.

Review thread on torch/nn/parallel/distributed.py (resolved).
Review thread:

    if hasattr(buf, '_ddp_ignored') and buf._ddp_ignored:
        continue

    buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype)

@zhaojuanmao (Contributor): Nit: could we do buf = buf.to(...)?

@rohan-varma (Member, Author): That doesn't work, since it would just point buf at the fp16 tensor while the buffer returned by .buffers() wouldn't be updated. FSDP does the same thing.

If we'd like to avoid .data usage here, we could de-register the float32 buffer and register an fp16 buffer instead.
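
A hedged sketch of that alternative, re-registering each buffer in the lower precision so that `module.buffers()` reflects the cast tensor; the helper name is hypothetical and this is not the code in the PR.

```python
import torch
import torch.nn as nn

def _cast_buffers_by_reregistering(module: nn.Module, buffer_dtype: torch.dtype) -> None:
    for submodule in module.modules():
        # recurse=False gives local buffer names, which register_buffer expects.
        for name, buf in list(submodule.named_buffers(recurse=False)):
            if getattr(buf, "_ddp_ignored", False):
                continue
            # Re-registering an existing buffer name replaces it in ._buffers,
            # so .buffers() and state_dict() see the low-precision tensor.
            submodule.register_buffer(name, buf.to(dtype=buffer_dtype))

# Usage sketch:
# _cast_buffers_by_reregistering(model, torch.float16)
```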


Review thread:

    if not hasattr(param, '_mp_param'):
        param._mp_param = torch.zeros_like(
            param.data,

@zhaojuanmao (Contributor): Nit: remove .data here.

Review thread:

    # back to at the end of forward / backward.
    param._fp_param = param.data

    def _cast_forward_inputs(

@zhaojuanmao (Contributor): Nit: move this to a util helper.
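
For illustration, a shared `_cast_forward_inputs` utility could look roughly like the sketch below, casting only floating-point tensors in the forward args/kwargs to the configured param dtype. FSDP has a similar private helper, but this body is an assumption, not the merged code; `tree_map` is PyTorch's private pytree utility.

```python
import torch
from torch.utils._pytree import tree_map

def _cast_forward_inputs(dtype, *args, **kwargs):
    # No-op when mixed precision does not specify an input dtype.
    if dtype is None:
        return args, kwargs

    def cast_fn(x):
        if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
            return x.to(dtype)
        return x  # leave ints, bools, and non-tensors untouched

    return tree_map(cast_fn, args), tree_map(cast_fn, kwargs)

# Usage sketch inside a pre-forward hook:
# args, kwargs = _cast_forward_inputs(torch.float16, *args, **kwargs)
```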

Review thread:

    _alloc_storage(param._mp_param, param.data.size())
    # copy() implicitly casts to low precision
    with torch.no_grad():
        param._mp_param.copy_(param.data)

@zhaojuanmao (Contributor): Nit: remove .data.

Review thread:

    # is saved and .grad field is set to None, bypassing
    # this issue.
    if param.grad is not None:
        param.grad.data = param.grad.to(self.mixed_precision.param_dtype)

@zhaojuanmao (Contributor): Nit: remove .data?

@rohan-varma (Member, Author): Similar to above, this would de-link the pointers for grad_as_bucket_view (when set_grad_none=False), so I decided to keep it as is for now.

@rohan-varma added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Mar 9, 2023.
rohan-varma added a commit that referenced this pull request on Mar 9, 2023 (ghstack-source-id: 182379630).
rohan-varma added a commit that referenced this pull request on Mar 10, 2023 (ghstack-source-id: 182562307).
rohan-varma added a commit that referenced this pull request on Mar 13, 2023 (ghstack-source-id: 182670902).
@rohan-varma (Member, Author):

@pytorchbot merge -f "CI unrelated"

@pytorchmergebot (Collaborator):

Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request on Mar 23, 2023.

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request on Mar 27, 2023.

Pull Request resolved: pytorch/pytorch#92882
Approved by: https://github.com/zhaojuanmao
@facebook-github-bot deleted the gh/rohan-varma/631/head branch on June 8, 2023.
Labels: ciflow/trunk (Trigger trunk jobs on your pull request), Merged, release notes: distributed (c10d)