[Feature request] Add grad norm monitoring/logging #1407

Closed
gau-nernst opened this issue Aug 25, 2024 · 6 comments

Comments

@gau-nernst
Contributor

Personally I have found that monitoring the grad norm is useful for understanding training stability. It is also useful for setting an appropriate clipping value (though I don't think torchtune supports grad norm clipping atm?).

Some considerations:

@ebsmothers
Contributor

@gau-nernst thanks for this suggestion. Actually we had a similar discussion a few months back in #897, maybe I was too much of a stickler about it at that time 😅. In addition to the comments I left there, I agree with your point on not slowing down training. I think your proposal to only calculate at logging step is reasonable, but in practice I think many of our configs set log_every_n_steps=1 so in that case it doesn't do much for us anyways.

One alternative to using nn.utils.clip_grad_norm is to provide a bit more flexibility in the metrics we log in general. There is a world where we allow passing a list of callables that return any custom metrics that a user may want to compute and update the log dict accordingly somewhere around here. The main question is how we actually design this, but I can imagine the signature of each callable being optional loss, optional model, and optional optimizer, which should be reasonably flexible (I don't wanna get into per-batch vs per-step vs per-epoch logging though cause that's a whole other can of worms). The downside of this is that (a) we don't actually clip the grad norm (though technically I guess the user can do it themselves by calling that API) and (b) it might be a bit overengineered.
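For concreteness, a rough sketch of what such a hook could look like (all names here are made up for illustration, this is not an actual torchtune API):

```python
from typing import Callable, Dict, Optional

import torch
import torch.nn as nn

# A metric hook takes optional loss/model/optimizer and returns extra metrics.
MetricHook = Callable[..., Dict[str, float]]


def grad_norm_metric(
    loss: Optional[torch.Tensor] = None,
    model: Optional[nn.Module] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, float]:
    # Total L2 norm over all gradients, computed without clipping anything.
    assert model is not None
    grads = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
    total_norm = torch.norm(torch.stack(grads), 2)
    return {"grad_norm": total_norm.item()}


# Inside the recipe, around the existing logging step (sketch):
# log_dict = {"loss": loss.item(), "lr": ...}
# for hook in metric_hooks:
#     log_dict.update(hook(loss=loss, model=model, optimizer=optimizer))
# self._metric_logger.log_dict(log_dict, step=global_step)
```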

Otherwise there is the nn.utils.clip_grad_norm route. My main question here is around how we enable optional grad norm logging (so we don't have to slow down training) and/or clipping nicely from the config. Maybe something like clip_grad_norm: Optional[float] and log_grad_norm: bool?

Personally I am open to either of these approaches, would be interested to hear your thoughts on the pros and cons here as well.

Agree that doing this properly for FSDP will need a bit more thought (I assume we would want the norm across all ranks and not per-rank? Also I believe if we use …). But it's fine to punt on it for now.

@gau-nernst
Contributor Author

From the discussion in #897, I agree with you that we should not enable gradient clipping by default. It should be set explicitly by the user (I've been burned many times by default hparams differing across HF models 🌚).

I don't think the default log_every_n_steps=1 is a big problem. If a user cares about perf, they will set this to a larger value (perhaps we should benchmark this some time too: what is the impact of logging every step?). Keeping the default at 1 is fine; it's useful for debugging (logs show up early).

In terms of benchmarking, we can also check how much calculating the grad norm every step is gonna cost us. Maybe it's not that much? 🤔

provide a bit more flexibility in the metrics we log in general

I think this is nice on paper but a nightmare to design and maintain 😢. Realistically, I'm not sure there are that many other useful metrics to log (apart from task-specific things, which should already be hard-coded in their own respective recipes).

For now, I think it is reasonable to have clip_grad_norm: Optional[float].

  • If it is None: do nothing
  • If it is a float value: clip grad norm -> get the grad norm value for free -> log it (see the sketch below)

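For the float case, a toy sketch of what I mean (standalone example, not actual recipe code):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
clip_grad_norm = 1.0  # would come from the config, e.g. clip_grad_norm: 1.0

loss = model(torch.randn(4, 8)).sum()
loss.backward()

log_dict = {"loss": loss.item()}
if clip_grad_norm is not None:
    # clip_grad_norm_ clips in place and returns the total norm measured
    # before clipping, so logging it comes essentially for free.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
    log_dict["grad_norm"] = grad_norm.item()

optimizer.step()
optimizer.zero_grad(set_to_none=True)
```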
So now the question is what to do when clip_grad_norm=None. To summarize the options we have so far:

  • Always log the grad norm, but only calculate it on logging steps. We should benchmark the impact of this to make sure.
  • Expose an optional flag log_grad_norm: bool, which lets the user control whether to log the grad norm.

@ebsmothers
Contributor

@gau-nernst I think your proposal makes sense.

So now the question is what to do when clip_grad_norm=None. To summarize the options we have so far:

  • Always log the grad norm, but only calculate it on logging steps. We should benchmark the impact of this to make sure.
  • Expose an optional flag log_grad_norm: bool, which lets the user control whether to log the grad norm.

In the absence of any data, I would lean towards the second option just to be safe. However, if we do find that the perf impact of logging grad norm is negligible, the first option would be fine too (and simpler). For benchmarking purposes we may want to look at distributed too since inevitably we will want to add it at some point and the clip_grad_norm call is likely to be more expensive in that case.

@ebsmothers
Contributor

Hey @gau-nernst, are you working on this one already? If not, we may have someone who can help out here.

@gau-nernst
Contributor Author

I'm not working on this. You can assign this to someone else.

@gau-nernst
Contributor Author

This is not implemented for distributed recipes yet, right? So maybe we can keep this issue open.

I was adding this feature to my codebase with FSDP2 and thought it might be useful for torchtune too.

So it's pretty straightforward to support this feature with FSDP2.
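Roughly, a sketch of what I do in my codebase (based on my understanding of FSDP2, not torchtune's actual recipe code):

```python
import torch
# public path in newer PyTorch; older versions expose it as torch.distributed._tensor
from torch.distributed.tensor import DTensor


def clip_and_get_grad_norm(model: torch.nn.Module, max_norm: float) -> float:
    # With FSDP2 (fully_shard) the params/grads are DTensors, so clip_grad_norm_
    # already accounts for all ranks; the returned norm is itself a DTensor.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if isinstance(grad_norm, DTensor):
        # materialize the global (cross-rank) norm for logging
        grad_norm = grad_norm.full_tensor()
    return grad_norm.item()
```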

Another separate issue: I think all metrics logged by torchtune are "local" metrics, e.g. the loss value is the loss on rank 0 only. To get an "accurate" loss value, we would need to do an all-reduce. Might not be so important...
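If we ever want that, it could look something like this (just a sketch; ReduceOp.AVG needs a backend that supports it, e.g. NCCL, otherwise SUM divided by world size):

```python
import torch
import torch.distributed as dist


def global_avg_loss(local_loss: torch.Tensor) -> float:
    # Average the per-rank losses so the logged value reflects all ranks,
    # not just rank 0.
    loss = local_loss.detach().clone()
    dist.all_reduce(loss, op=dist.ReduceOp.AVG)
    return loss.item()
```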
