Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addressing Issue #5241: Updating deepspeed/runtime/zero/stage_1_and_2.py #5252

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,7 +1307,8 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
total_norm = total_norm_cuda[0].item()**(1. / norm_type)

if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
logger.info(f"Warning: invalid gradient detected. Please check your model implementation/configuration to improve the numerical stability.")
total_norm = -1.
Comment on lines +1310 to +1311
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will the following work? Based on the non-offload code path.

Suggested change
logger.info(f"Warning: invalid gradient detected. Please check your model implementation/configuration to improve the numerical stability.")
total_norm = -1.
total_norm = torch.tensor(-1.0, device=self.device, dtype=torch.float)

Copy link
Author

@desire2020 desire2020 Mar 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the idea here is to warn the user that this is an unexpected behavior and ususally it would cause/is caused by an error in the gradients. I'm fine with the current temporary fix, but I still hope we can probably throw an exception here or print some warning information to let the user know.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@desire2020, I was making two points

  1. It is better for the offload and non-offload code paths to have similar behavior and appearance to users. In other words, total_norm should be tensors in both paths, and same logging behavior.
  2. Invalid gradient norms would later trigger overflow detection for that iteration. And DS and most frameworks already handle that correctly, with appropriate warning messages. Can you please check that there is no subsequent overflow message warning?

What do you think?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @tjruwase , yes I checked there is no subsequent overflow message warning. The code immediately crashed, after returning -1 here and then when calculating the grad norm elsewhere, PyTorch throws an uncaught exception when we are trying to calculate the norm of an (autocasted) torch.int64. I was using deepspeed directly with the official huggingface lm example code with mixed precision of bf16, didn't change too much of it. For the error message, please refer to Issue #5241. It's the same in my case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@desire2020, thanks for the response. I think you have stumbled on a bigger problem in our code. Normally, we check for overflows before [computing gradient norms]. We assume that the conditions for an overflow are the same as for gradient norm of -1, and so norm computation is skipped on overflows. For bf16 training, we don't check for overflows because we assume overflows are impossible. But it seems your test case contradicts our assumption. Can you try enabling overflow checks for bf16 training to see if overflow is detected? Thanks!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @tjruwase , sorry if I'm asking a stupid question, is this a currently supported feature or we'd be expecting it in the next release of deepspeed? 'Cause I don't find related implementation of this feature in the current version.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@desire2020, no worries, I was not clear. I was asking if you could add this feature to your PR? Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More specifically, the desired feature is to add "check_overflow" option into the "bf16" dict of the ds_config.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tjruwase Ok, do you think I should add it for fp16 at the same time? If people are certain that their objective is stable and simple, they can use this option to further boost their fp16 mixed precision training too.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but let's make the fp16 option true by default to preserve BC. Thanks!


return total_norm

Expand Down
Loading