
[RFC] Reduce nf4 quant mem #315

Closed
wants to merge 4 commits

Conversation

@ebsmothers (Contributor) commented Jun 4, 2024

Currently NF4 quantization of large tensors can cause some pretty substantial memory spikes. Consider the following script for quantizing the Llama3-8B output projection weight into NF4:

from torch import nn
from torchtune.utils import get_memory_stats, get_device
from torchao.dtypes.nf4tensor import to_nf4

def main():
    
    device = get_device('cuda')

    # Size of Llama3-8B output projection weight
    big_linear = nn.Linear(in_features=4096, out_features=128256, bias=False, device=device)
    memory_stats = get_memory_stats(device=device)
    print(f"before quantize: {memory_stats}")
    
    # Quantize with ao
    ao_quant = to_nf4(big_linear.weight)
    memory_stats = get_memory_stats(device=device)
    print(f"after ao quant: {memory_stats}")



if __name__ == "__main__":
    main()

On main, this currently prints

before quantize: {'peak_memory_active': 2.101346304, 'peak_memory_alloc': 2.101346304, 'peak_memory_reserved': 2.101346304}
after ao quant: {'peak_memory_active': 37.865276416, 'peak_memory_alloc': 37.865276416, 'peak_memory_reserved': 37.977325568}

The peak memory is almost 38 GB, which I believe is because diff = (value - nf4).abs() creates a tensor 16x larger than the weight in order to do the pairwise comparison of value against the 16 NF4 values.
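
For intuition, a rough back-of-the-envelope estimate of that intermediate (assuming a float32 broadcasted diff tensor; the exact dtype may differ in practice) lines up with the observed spike:

numel = 4096 * 128256      # elements in the Llama3-8B output projection weight
nf4_levels = 16            # size of the NF4 lookup table
bytes_per_element = 4      # assuming float32 for the broadcasted diff tensor
print(f"{numel * nf4_levels * bytes_per_element / 1e9:.1f} GB")  # ~33.6 GB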

Edit: Updated based on @gau-nernst's and @cpuhrsch's suggestions. Instead of a for loop, we just torch.compile quantize_tensor_nearest (conditional on not running on Windows), then cast the final indices to uint8. This results in

before quantize: {'peak_memory_active': 2.101346304, 'peak_memory_alloc': 2.101346304, 'peak_memory_reserved': 2.101346304}
after ao quant: {'peak_memory_active': 5.29545728, 'peak_memory_alloc': 5.29545728, 'peak_memory_reserved': 5.668601856}
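
A minimal sketch of that approach (the maybe_compile helper and the exact platform check are illustrative assumptions; the actual uint8 cast is shown in the diff further down). torch.compile can fuse the broadcasted subtraction, abs, and min so the full (numel, 16) intermediate never needs to be materialized:

import sys

import torch

def maybe_compile(fn):
    # Illustrative guard: skip torch.compile on Windows, where it may not be supported
    return fn if sys.platform == "win32" else torch.compile(fn)

@maybe_compile
def quantize_tensor_nearest(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
    """Map each element of a flattened `value` to the index of its closest NF4 level."""
    value = value.unsqueeze(-1)   # (numel, 1)
    diff = (value - nf4).abs()    # (numel, 16) pairwise distances
    # uint8 indices take 1 byte/element instead of 8 for the default int64
    return diff.min(dim=-1).indices.to(dtype=torch.uint8)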

pytorch-bot (bot) commented Jun 4, 2024

🔗 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/315

❌ 6 New Failures as of commit 6e672b4 with merge base d75f450.

@msaroufim requested a review from drisspg, Jun 4, 2024 02:54
@gau-nernst (Collaborator)

Can't we just torch.compile() the function quantize_tensor_nearest()?

@ebsmothers (Contributor, Author)

@gau-nernst thanks for the suggestion. I decorated quantize_tensor_nearest with @torch.compile and still see peak allocated memory of around 12.7 GB. Any ideas on how I could reduce it further?

@gau-nernst (Collaborator)

On my phone right now so I can't check in detail. From the look of it, I think the original implementation returns int32 indices, while your updated version returns uint8. That might account for the difference in memory usage. Try .to(torch.uint8) in the torch.compile version as well. I think torch.compile should be able to generate efficient Triton code for this (you can check with TORCH_LOGS="output_code").
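
For reference, a quick way to try that on a CUDA machine (the tensor names are stand-ins; torch._logging.set_logs is equivalent to setting the TORCH_LOGS environment variable):

import torch

# Print the code Inductor generates, equivalent to TORCH_LOGS="output_code"
torch._logging.set_logs(output_code=True)

nf4_table = torch.linspace(-1.0, 1.0, 16, device="cuda")  # stand-in for the NF4 table
values = torch.randn(2**20, device="cuda")                # a flattened weight chunk

quantize = torch.compile(
    lambda v, t: (v.unsqueeze(-1) - t).abs().min(dim=-1).indices.to(torch.uint8)
)
print(quantize(values, nf4_table).dtype)  # torch.uint8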

@ebsmothers (Contributor, Author)

@gau-nernst that gets it to comparable memory, thanks! Is it reasonable to change the PR to just add the uint8 cast + torch.compile decorator to this method then?

@cpuhrsch (Contributor) commented Jun 4, 2024

@ebsmothers - I'd say "yes" to the uint8 change, but I'm not sure about the compile change, because compile might not work for users on certain platforms (e.g. Windows, last time I checked) :/ Maybe we should have a platform-guarded wrapper for compile?

@drisspg (Contributor) commented Jun 4, 2024

I think this makes sense. The other way I would do this, which might be more performant, is to do this step in chunks, similar to this PR:

#196
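
For illustration, a chunked version might look something like this (the chunk size is arbitrary here, and the referenced PR may do it differently):

import torch

def quantize_tensor_nearest_chunked(value: torch.Tensor, nf4: torch.Tensor,
                                    chunk_size: int = 1024 * 1024) -> torch.Tensor:
    """Quantize in fixed-size chunks so the (chunk, 16) intermediate stays bounded."""
    flat = value.reshape(-1)
    out = torch.empty(flat.numel(), dtype=torch.uint8, device=value.device)
    for start in range(0, flat.numel(), chunk_size):
        chunk = flat[start:start + chunk_size].unsqueeze(-1)   # (chunk, 1)
        diff = (chunk - nf4).abs()                             # (chunk, 16)
        out[start:start + chunk_size] = diff.min(dim=-1).indices.to(torch.uint8)
    return out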

@ebsmothers (Contributor, Author)

Thanks @cpuhrsch and @drisspg for the guidance here. I'm kinda indifferent between the two approaches (compile + explicit cast vs. chunking); would you recommend one over the other?

@cpuhrsch (Contributor) commented Jun 4, 2024

@ebsmothers - I'd pick the one that sufficiently reduces memory without introducing performance regressions. So if you can get away with just using compile and not doing chunking, maybe that'll preserve perf and also fix your memory issue.

@gau-nernst (Collaborator)

Just a comment: we can probably torch.compile convert_to_norm_float_weight() as well, which would further reduce memory usage (if users don't call quantize_tensor_nearest() directly). (And the chunking logic could be removed, unless you are on Windows 😆)

@drisspg (Contributor) commented Jun 5, 2024

IMO a big part of this subclass is that it's an eager implementation that is very easy to hack on and easy to debug. I am in general not a big fan of baking in torch.compile. All that being said, I don't want to thrash, so carry on.

@cpuhrsch (Contributor) commented Jun 5, 2024

@drisspg - Yes, that's a good point. If we turn on torch.compile by default, it'll be harder to debug.

@gau-nernst - Maybe we'll just add a test that verifies that torch.compile works. So we'll write the code in a way that it's compile-able and results in efficient code, but the user does still have to turn on compile themselves. What do you think?
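
Something along these lines could work as that test (a hypothetical sketch, assuming quantize_tensor_nearest is importable from torchao.dtypes.nf4tensor; the test actually added to the repo may differ):

import torch
from torchao.dtypes.nf4tensor import quantize_tensor_nearest

def test_quantize_tensor_nearest_compiles():
    # Compiled and eager paths should produce identical indices
    nf4_table = torch.linspace(-1.0, 1.0, 16)
    values = torch.randn(4096)
    eager = quantize_tensor_nearest(values, nf4_table)
    compiled = torch.compile(quantize_tensor_nearest)(values, nf4_table)
    torch.testing.assert_close(eager, compiled)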

@gau-nernst (Collaborator)

Sounds good to me! Just a question: if a function is inside a user-facing API function, but the user can't torch.compile the outer function (e.g. functions that quantize an nn.Module - generally they can't be compiled, right? I just checked a few of them under torchao.quantization), they probably can't use the "compiled" version of the inner function? (Unless there is an extra flag to the outer function, but that doesn't sound ideal.)
This might not be a problem with NF4, though; I'm not familiar with the code here.

@cpuhrsch (Contributor) commented Jun 5, 2024

@gau-nernst - Hm, generally torch.compile should work on most PyTorch code, although it might cause graph breaks (they can be detected with TORCH_LOGS='graph_breaks'), which could impact performance. If you tried to compile something and it didn't work, I think it'd be worth seeing if we can fix it. In any case, it's also ok to have multiple calls to torch.compile within a process if there's code that prevents us from using it (but really, if there is such code, we should figure out why, for best overall performance).

@gau-nernst (Collaborator)

@cpuhrsch ok, I will look through some of our public quant API and open a separate issue if they can't run through torch.compile.

@@ -698,7 +699,7 @@ def quantize_tensor_nearest(
     value = value.unsqueeze(-1)  # (numel, 1)
     # Compare the value tensor with the nf4 tensor element-wise
     diff = (value - nf4).abs()
-    closest_nf4 = diff.min(dim=-1).indices
+    closest_nf4 = diff.min(dim=-1).indices.to(dtype=torch.uint8)
Review comment (Contributor):

On another note, it seems like we could use 4 bits here? We have https://github.com/pytorch/ao/blob/03e2c9b056a17a0e6d5ae31c2a73df36345f886f/torchao/dtypes/uint4.py , but it likely needs some work. I assume closest_nf4 here is an index into the NF4 table, which by definition is only 4 bits, and thus there aren't more than 4 bits' worth of indices?

Review comment (Collaborator):

Bit-packing logic is done in convert_to_norm_float_weight(). I don't know the original author's intention, but one benefit of not doing bit-packing in quantize_tensor_nearest() is that we can do correctness tests without unpacking bits.

Review comment (Contributor):

Mhm, I see. So then using compile at a higher level might still avoid materializing the full int8 tensor (thus reducing peak memory utilization during construction).

Review comment (Contributor):

Also, from what I remember when I wrote this, the bitshifting is only supported on uint8 dtype:

combined_blocks = quantized_blocks[::2] << 4 | quantized_blocks[1::2]
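
For reference, a self-contained round trip of that packing scheme on uint8 indices (an illustration, not the library code):

import torch

# Two 4-bit indices per byte: even positions in the high nibble, odd in the low nibble
indices = torch.randint(0, 16, (8,), dtype=torch.uint8)
packed = indices[::2] << 4 | indices[1::2]

unpacked = torch.empty_like(indices)
unpacked[::2] = packed >> 4      # recover the high nibble
unpacked[1::2] = packed & 0x0F   # recover the low nibble
assert torch.equal(indices, unpacked)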

@drisspg (Contributor) commented Jun 6, 2024

Also, one other thing I noticed: I think you are using an older version of torchao that doesn't have my chunking PR.

I get this:

before quantize: {'peak_memory_active': 2.101346304, 'peak_memory_alloc': 2.101346304, 'peak_memory_reserved': 2.101346304}
after ao quant: {'peak_memory_active': 5.295585792, 'peak_memory_alloc': 5.295585792, 'peak_memory_reserved': 5.49453824}

@drisspg (Contributor) commented Jun 6, 2024

So I still kind of think all the memory savings are just coming from the previous chunking PR?

@@ -698,7 +699,7 @@ def quantize_tensor_nearest(
     value = value.unsqueeze(-1)  # (numel, 1)
     # Compare the value tensor with the nf4 tensor element-wise
     diff = (value - nf4).abs()
-    closest_nf4 = diff.min(dim=-1).indices
+    closest_nf4 = diff.min(dim=-1).indices.to(dtype=torch.uint8)


Regarding diff.min(dim=-1).indices.to(torch.uint8): if diff.argmin(dim=-1, dtype=torch.uint8) were supported, this might be neater.
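
For comparison, a small sketch of what works today (argmin without a dtype argument, which appears to be hypothetical above), using stand-in tensors:

import torch

value = torch.randn(1024).unsqueeze(-1)   # (numel, 1)
nf4 = torch.linspace(-1.0, 1.0, 16)       # stand-in for the NF4 table
diff = (value - nf4).abs()

# argmin returns only the indices (as int64), so the uint8 cast is still needed
closest_nf4 = diff.argmin(dim=-1).to(torch.uint8)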

@ebsmothers (Contributor, Author)

So I still kind of think all the memory savings are just coming from the previous chunking PR?

Oh god, thank you @drisspg: this is already handled by the PR you landed previously. My torchtune env with ao==0.1.0 was evidently bleeding into my dev env and I did not catch it. I agree your existing PR solves this problem exactly; sorry for the thrash here. I am gonna close this and get torchtune on 0.2.0 ASAP.
