[RFC] Reduce nf4 quant mem #315
Conversation
🔗 Helpful links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/315. Note: links to docs will display an error until the docs builds have been completed.
❌ 6 New Failures as of commit 6e672b4 with merge base d75f450: the following jobs have failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Can't we just …
@gau-nernst thanks for the suggestion. I decorated …
On my phone right now so I can't check in detail. From the look of it, I think the original implementation returns int32 indices, while your updated version returns uint8. That might account for the difference in memory usage. Try .to(torch.uint8) in the torch.compile version also. I think torch.compile should be able to generate efficient Triton code for this (can check with TORCH_LOGS="output_code" …)
@gau-nernst that gets it to comparable memory, thanks! Is it reasonable to change the PR to just add the uint8 cast + torch.compile decorator to this method then?
@ebsmothers - I'd say "yes" to the uint8 change, but I'm not sure about the compile change, because compile might not work for users on certain platforms (e.g. Windows, last time I checked) :/ Maybe we should have a platform-guarded wrapper for compile?
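For illustration, a minimal sketch of the kind of platform-guarded wrapper being suggested; the helper name `maybe_compile` is hypothetical and not part of torchao:

```python
import platform

import torch


def maybe_compile(fn):
    """Apply torch.compile only on platforms where it is supported.

    torch.compile was not supported on Windows at the time of this
    discussion, so fall back to the eager function there.
    """
    if platform.system() == "Windows":
        return fn
    return torch.compile(fn)


# Usage: wrap the quantization helper instead of calling torch.compile directly.
# quantize_tensor_nearest = maybe_compile(quantize_tensor_nearest)
```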
I think this makes sense. The other way I would do this, which might be more performant, is to do this step in chunks, similar to this PR: …
@ebsmothers - I'd pick the one that sufficiently reduces memory without introducing performance regressions. So if you can get away with just using compile and not doing chunking, maybe that'll preserve perf and also fix your memory issue.
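As a rough illustration of the chunked alternative mentioned above (a sketch, not the code from the linked chunking PR; the function name and default chunk size are made up):

```python
import torch


def quantize_tensor_nearest_chunked(
    value: torch.Tensor, nf4: torch.Tensor, chunk_size: int = 1 << 20
) -> torch.Tensor:
    """Map each element of `value` to the index of its nearest NF4 value.

    Processing the flattened input in fixed-size chunks avoids materializing
    the full (numel, 16) difference tensor at once, bounding peak memory.
    """
    out = []
    for chunk in value.reshape(-1).split(chunk_size):
        diff = (chunk.unsqueeze(-1) - nf4).abs()  # (chunk_size, 16)
        out.append(diff.argmin(dim=-1).to(torch.uint8))
    return torch.cat(out)
```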
Just a comment. We can probably torch.compile …
IMO a big part of this subclass is that it's an eager implementation that is very easy to hack on and easy to debug. I am in general not a big fan of baking in torch.compile. All that being said, I don't want to thrash, so carry on.
@drisspg - Yes, that's a good point. If we turn on torch.compile by default, it'll be harder to debug. @gau-nernst - Maybe we'll just add a test that verifies that torch.compile works. So we'll write the code in a way that it's compile-able and results in efficient code, but the user does still have to turn on compile themselves. What do you think?
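A sketch of the kind of compile-compatibility test being suggested; the standalone `quantize_nearest` function mirrors the snippet in the diff below rather than importing the actual torchao API, and the table values are placeholders:

```python
import torch

# Placeholder 16-entry lookup table; the real NF4 table values differ.
NF4_TABLE = torch.linspace(-1.0, 1.0, 16)


def quantize_nearest(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
    """Return the uint8 index of the nearest table entry for each element."""
    diff = (value.unsqueeze(-1) - nf4).abs()
    return diff.min(dim=-1).indices.to(torch.uint8)


def test_compiled_matches_eager():
    value = torch.randn(1024)
    eager = quantize_nearest(value, NF4_TABLE)
    compiled = torch.compile(quantize_nearest)(value, NF4_TABLE)
    torch.testing.assert_close(eager, compiled)
```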
Sounds good to me! Just a question. If a function is inside a user-facing API function, but the user can't torch.compile the outer function (e.g. functions that quantize an …
@gau-nernst - Hm, generally …
@cpuhrsch ok, I will look through some of our public quant APIs and open a separate issue if they can't run through torch.compile.
```diff
@@ -698,7 +699,7 @@ def quantize_tensor_nearest(
         value = value.unsqueeze(-1)  # (numel, 1)
         # Compare the value tensor with the nf4 tensor element-wise
         diff = (value - nf4).abs()
-        closest_nf4 = diff.min(dim=-1).indices
+        closest_nf4 = diff.min(dim=-1).indices.to(dtype=torch.uint8)
```
On another note, it seems like we could use 4 bits here? We have https://github.com/pytorch/ao/blob/03e2c9b056a17a0e6d5ae31c2a73df36345f886f/torchao/dtypes/uint4.py, but it likely needs some work. I assume closest_nf4 here is an index into the nf4 table, which by definition is only 4 bits, and thus there aren't more than 4 bits' worth of indices?
Bit-packing logic is done in convert_to_norm_float_weight(). I don't know the original author's intention, but one benefit of not doing bit-packing in quantize_tensor_nearest() is that we can do correctness tests without unpacking bits.
Mhm, I see. So then using compile at a higher level might still avoid materializing the full int8 tensor (thus reducing peak memory utilization during construction).
Also, from what I remember when I wrote this, the bit-shifting is only supported on the uint8 dtype:

ao/torchao/dtypes/nf4tensor.py, line 661 at 03e2c9b:

```python
combined_blocks = quantized_blocks[::2] << 4 | quantized_blocks[1::2]
```
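As an aside, a small sketch of what that packing line does and how it can be reversed; it assumes uint8 indices in [0, 15] and an even number of elements, and the helper names are hypothetical:

```python
import torch


def pack_nf4_indices(idx: torch.Tensor) -> torch.Tensor:
    """Pack two 4-bit indices per byte: even positions go into the high nibble."""
    return (idx[::2] << 4) | idx[1::2]


def unpack_nf4_indices(packed: torch.Tensor) -> torch.Tensor:
    """Recover the original 4-bit indices from the packed uint8 tensor."""
    out = torch.empty(packed.numel() * 2, dtype=torch.uint8, device=packed.device)
    out[::2] = packed >> 4
    out[1::2] = packed & 0x0F
    return out


idx = torch.randint(0, 16, (8,), dtype=torch.uint8)
assert torch.equal(unpack_nf4_indices(pack_nf4_indices(idx)), idx)
```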
Also, one other thing I noticed: I think you are using an older version of torchao that doesn't have my chunking PR. I get this:

[screenshot omitted]

So I still kind of think all the memory savings are just coming from the previous chunking PR?
```diff
@@ -698,7 +699,7 @@ def quantize_tensor_nearest(
         value = value.unsqueeze(-1)  # (numel, 1)
         # Compare the value tensor with the nf4 tensor element-wise
         diff = (value - nf4).abs()
-        closest_nf4 = diff.min(dim=-1).indices
+        closest_nf4 = diff.min(dim=-1).indices.to(dtype=torch.uint8)
```
Regarding diff.min(dim=-1).indices.to(torch.uint8): if diff.argmin(dim=-1, dtype=torch.uint8) is supported, this might get neater.
Oh god... thank you @drisspg, this is already handled by the PR you landed previously. My torchtune env with ao==0.1.0 was evidently bleeding into my dev env and I did not catch it. I agree your existing PR solves this problem exactly; sorry for the thrash here. I am gonna close this and get torchtune on 0.2.0 ASAP.
Currently NF4 quantization of large tensors can cause some pretty substantial memory spikes. Consider the following script for quantizing the Llama3-8B output projection weight into NF4:

On main, this currently prints

The peak memory is almost 38 GB, I think due to the fact that `diff = (value - nf4).abs()` is creating a 16x larger tensor to do the pairwise comparison of `value` to `nf4` tensors.

Edit: Updated based on @gau-nernst and @cpuhrsch's suggestions. Instead of a for loop, we just `torch.compile` `quantize_tensor_nearest` (conditional upon not running on Windows), then cast the final indices to uint8. This results in …
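For reference, a minimal sketch of the kind of repro script described above; this is not the original script from the PR, and both the (128256, 4096) output-projection shape and the `to_nf4` entry point are assumptions:

```python
import torch
from torchao.dtypes.nf4tensor import to_nf4

# Roughly the shape of the Llama3-8B output projection weight (vocab x hidden).
weight = torch.randn(128256, 4096, dtype=torch.bfloat16, device="cuda")

torch.cuda.reset_peak_memory_stats()
nf4_weight = to_nf4(weight)
print(f"Peak memory during NF4 quantization: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
```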