Sharded Accelerator 1/n: Expose clip gradients to plugins via abstract class #4639
Conversation
…e within accelerator to clip gradients
Need to sort out APEX FP16 (with opt level set to O2 all parameters are in FP16, so I need to change the epsilon to match). EDIT: done :)
…gradients), add override for O2 support apex
Great! Just one question about clipping for CPU.
Smaller comments, LGTM
…tor. Default to standard clip function for vanilla torch
Great PR!
 def _clip_gradients(self, optimizer, grad_clip_val):
     if self.trainer.amp_backend:
         self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer)
     else:
-        parameters = model.parameters()
-
-        max_norm = grad_clip_val
-        norm_type = float(2.0)
-
-        if isinstance(parameters, torch.Tensor):
-            parameters = [parameters]
-        parameters = list(filter(lambda p: p.grad is not None, parameters))
-
-        if norm_type == math.inf:
-            total_norm = max(p.grad.data.abs().max() for p in parameters)
-        else:
-            device = parameters[0].device
-            out = torch.empty(len(parameters), device=device)
-            for i, p in enumerate(parameters):
-                torch.norm(p.grad.data.to(device), norm_type, out=out[i])
-            total_norm = torch.norm(out, norm_type)
-
-        eps = EPSILON_FP16 if self.trainer.precision == 16 else EPSILON
-        clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
-        clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
-        for p in parameters:
-            p.grad.data.mul_(clip_coef.to(p.grad.data.device))
+        model = self.trainer.get_model()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0)
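For context, the replacement relies on torch.nn.utils.clip_grad_norm_, which computes the total gradient norm and rescales all gradients in place, covering what the removed manual code did (apart from the precision-specific epsilon). A minimal, self-contained sketch of that call, using a toy model purely as a placeholder:

import torch

# Toy model and backward pass, purely illustrative.
model = torch.nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).sum()
loss.backward()

# clip_grad_norm_ returns the total 2-norm of all gradients before clipping
# and scales every gradient in place so the combined norm is at most max_norm.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5, norm_type=2.0)
print(total_norm)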
I wonder if it would make sense to have a 32-bit precision plugin for the default case.
Then we wouldn't need this if/else block and could just call
self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer)
just as a refactoring idea for the future.
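A minimal sketch of that refactoring idea, assuming a hypothetical Precision32Plugin class (not an actual Lightning class): a default full-precision backend would wrap the same torch call, so the accelerator could always delegate and drop the if/else.

import torch

class Precision32Plugin:
    # Hypothetical default full-precision backend, mirroring the accelerator's else branch.
    def __init__(self, trainer):
        self.trainer = trainer

    def clip_gradients(self, grad_clip_val, optimizer):
        # Clip the whole model's gradients to a maximum 2-norm, exactly as the
        # accelerator currently does when no AMP backend is configured.
        model = self.trainer.get_model()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0)

# With this registered as precision_connector.backend, _clip_gradients could always call:
#   self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer)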
Talking about this block (GitHub selected more):
def _clip_gradients(self, optimizer, grad_clip_val):
    if self.trainer.amp_backend:
        self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer)
    else:
        model = self.trainer.get_model()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0)
I fully agree: a precision-32 plugin makes the most sense, because the accelerators shouldn't need to care about these details (it fits @justusschock's refactor proposal of making the accelerators hardware-only).
Out of scope for this PR, but definitely low-hanging fruit to reduce the likelihood of issues in the accelerators.
I'm going to do this in my refactor proposal :)
I'm a bit confused about what is meant to be overridden for the Accelerator versus the plugin: one seems to override clip_gradients in the child class, the other _clip_gradients.
Codecov Report
@@            Coverage Diff            @@
##           master   #4639     +/-   ##
=========================================
+ Coverage      90%      93%      +3%
=========================================
  Files         116      117       +1
  Lines        8881     8898      +17
=========================================
+ Hits         8012     8264     +252
+ Misses        869      634     -235
…t class (#4639)
* Added abstract precision plugin to expose clip_gradients function, use within accelerator to clip gradients
* Exclude model from override, keep optimizer (needed for sharded clip gradients), add override for O2 support apex
* Fix doc
* Applied codereview changes
* Refactored clip function to encapsulate tpu changes with tpu accelerator. Default to standard clip function for vanilla torch
* Pass correct grad clip val
* Moved var to property
* Apply code review suggestions
Ties to #4178
What does this PR do?
To ensure we have the appropriate behaviour when clipping gradients, I've made gradient clipping the responsibility of the precision plugin when one exists in the accelerator. I've also simplified the logic to reflect what should happen in the AMP/Apex cases, using the built-in torch clip_grad_norm_ function. This is required because the sharded accelerator will need a custom clip-gradient function :)
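As an illustration of the description above, here is a minimal, hypothetical sketch of an abstract precision plugin exposing clip_gradients, plus an Apex-style override that uses a larger epsilon because with opt level O2 all parameters are FP16. The class names, constants, and method bodies are assumptions for illustration, not the exact code added by this PR.

from abc import ABC, abstractmethod

import torch

EPSILON = 1e-6        # assumed epsilon for full precision
EPSILON_FP16 = 1e-5   # assumed larger epsilon when parameters are FP16 (Apex O2)


class PrecisionPlugin(ABC):
    """Abstract base class that exposes gradient clipping to the accelerator."""

    @abstractmethod
    def clip_gradients(self, grad_clip_val, optimizer):
        ...


class ApexPlugin(PrecisionPlugin):
    """Sketch of an Apex backend: clips the optimizer's own parameters and
    widens the epsilon, since opt level O2 keeps every parameter in FP16."""

    def clip_gradients(self, grad_clip_val, optimizer):
        parameters = [
            p for group in optimizer.param_groups
            for p in group["params"] if p.grad is not None
        ]
        # Total 2-norm across all gradients, then rescale in place if it exceeds the limit.
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), 2.0) for p in parameters]), 2.0
        )
        clip_coef = grad_clip_val / (total_norm + EPSILON_FP16)
        if clip_coef < 1:
            for p in parameters:
                p.grad.detach().mul_(clip_coef)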
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines.
Did you have fun?
Make sure you had fun coding 🙃