
Add option to move param to device before quantization #699

Merged · 6 commits · Aug 19, 2024

Conversation

@gau-nernst (Collaborator) commented Aug 17, 2024

Fixes #655

Test

```python
import gc
import time

import torch

from torchao._models.llama.model import ModelArgs, Transformer
from torchao.quantization.quant_api import int8_weight_only, quantize_


def get_model():
    # Default ModelArgs() (Llama2-7B-sized), materialized on CPU in BF16.
    with torch.device("meta"):
        model = Transformer(ModelArgs())
    model.to_empty(device="cpu").bfloat16()
    return model


def clear_cuda():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    assert torch.cuda.max_memory_reserved() == 0


# 1. Quantize on CPU, then move the whole model to CUDA.
model = get_model()
time0 = time.perf_counter()
quantize_(model, int8_weight_only())
model.to("cuda")
torch.cuda.synchronize()
duration = time.perf_counter() - time0
print("Quantize on CPU:")
print(f"  - Time taken: {duration:.2f} s")
print(f"  - Peak memory: {torch.cuda.max_memory_reserved() / 1e9:.2f} GB")
del model
clear_cuda()

# 2. Move the whole model to CUDA, then quantize on CUDA.
model = get_model()
time0 = time.perf_counter()
model.to("cuda")
quantize_(model, int8_weight_only())
torch.cuda.synchronize()
duration = time.perf_counter() - time0
print("Quantize on CUDA:")
print(f"  - Time taken: {duration:.2f} s")
print(f"  - Peak memory: {torch.cuda.max_memory_reserved() / 1e9:.2f} GB")
del model
clear_cuda()

# 3. New in this PR: move each param to CUDA right before quantizing it.
model = get_model()
time0 = time.perf_counter()
quantize_(model, int8_weight_only(), device="cuda")
torch.cuda.synchronize()
duration = time.perf_counter() - time0
print("Move to CUDA and quantize each param individually:")
print(f"  - Time taken: {duration:.2f} s")
print(f"  - Peak memory: {torch.cuda.max_memory_reserved() / 1e9:.2f} GB")
del model
clear_cuda()
```
```
Quantize on CPU:
  - Time taken: 10.48 s
  - Peak memory: 6.99 GB
Quantize on CUDA:
  - Time taken: 1.96 s
  - Peak memory: 14.50 GB
Move to CUDA and quantize each param individually:
  - Time taken: 1.94 s
  - Peak memory: 8.29 GB
```

@jerryzh168 Do you have any suggestions on what tests I should add for this?

This PR may also help with the slow NF4 quantization issue observed in torchtune (#642), though I'm not sure what the NF4 quantization API is. Does it use quantize_()?

@pytorch-bot commented Aug 17, 2024

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/699
✅ No failures as of commit 711d001 with merge base b523f9f.

@facebook-github-bot added the CLA Signed label on Aug 17, 2024
@jerryzh168 (Contributor) commented:

nf4 is not yet using the quantize_ API; it module-swaps linear with a LoRALinear. Maybe we could incorporate this into quantize_ as well if there is a weight-only quantization use case for nf4. cc @drisspg

The change makes sense. One thing I'm wondering: would it help further reduce CUDA memory usage if you also move the module back to the CPU (its original device) after quantization and release the used CUDA memory? Not sure if it would be slow, though.

For tests, I think you'll probably need to construct a model with linear + non-linear modules, compare the max CUDA memory for "Quantize on CUDA" and "Move to CUDA and quantize each param individually", and make sure the second one consumes less memory. I think you can add it in https://github.com/pytorch/ao/blob/main/test/quantization/test_quant_api.py.
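
A rough sketch of such a memory-comparison test (illustrative only: ToyModel, the layer sizes, the use of max_memory_allocated, and the exact assertion are placeholders, and it assumes the device keyword added in this PR):

```python
import unittest

import torch
import torch.nn as nn

from torchao.quantization.quant_api import int8_weight_only, quantize_


class ToyModel(nn.Module):
    """A small model mixing linear and non-linear modules."""

    def __init__(self, dim: int = 4096):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))


class TestQuantizeDevice(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_per_param_device_uses_less_memory(self):
        def peak_cuda_memory(fn):
            # Measure peak allocated CUDA memory while running fn().
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            fn()
            torch.cuda.synchronize()
            return torch.cuda.max_memory_allocated()

        # Baseline: move the whole model to CUDA, then quantize.
        model1 = ToyModel()

        def quantize_on_cuda():
            model1.to("cuda")
            quantize_(model1, int8_weight_only())

        mem_baseline = peak_cuda_memory(quantize_on_cuda)
        del model1

        # New path: quantize_ moves each param to CUDA right before quantizing it.
        model2 = ToyModel()

        def quantize_per_param():
            quantize_(model2, int8_weight_only(), device="cuda")

        mem_per_param = peak_cuda_memory(quantize_per_param)
        del model2

        self.assertLess(mem_per_param, mem_baseline)


if __name__ == "__main__":
    unittest.main()
```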

@gau-nernst (Collaborator, Author) commented:

> nf4 is not yet using the quantize_ API; it module-swaps linear with a LoRALinear.

I see. So users of NF4 need to implement this logic themselves, i.e. torchtune will have to implement it instead of torchao.

> One thing I'm wondering: would it help further reduce CUDA memory usage if you also move the module back to the CPU (its original device) after quantization and release the used CUDA memory? Not sure if it would be slow, though.

I can try this. But in the end, I think we should stick to one approach to avoid confusing users.

@jerryzh168 (Contributor) commented Aug 17, 2024

> I see. So users of NF4 need to implement this logic themselves, i.e. torchtune will have to implement it instead of torchao.

Yeah, according to Driss, this LoRALinear code is specific to the framework that uses nf4 (e.g. torchtune, Hugging Face, vLLM, etc.), so we don't want to offer an API for this in torchao.

@msaroufim (Member) commented:

> For tests, I think you'll probably need to construct a model with linear + non-linear modules, compare the max CUDA memory for "Quantize on CUDA" and "Move to CUDA and quantize each param individually", and make sure the second one consumes less memory. I think you can add it in main/test/quantization/test_quant_api.py.

Agreed, this is the right test.

> The change makes sense. One thing I'm wondering: would it help further reduce CUDA memory usage if you also move the module back to the CPU (its original device) after quantization and release the used CUDA memory? Not sure if it would be slow, though.

This scenario seems to optimize for people who have a GPU but want the model to run on the CPU. That seems rarer than the alternative: someone who wants to fit a model on their GPU that doesn't fit in full precision.

Otherwise, yeah, once CI is green we can merge this.

@gau-nernst (Collaborator, Author) commented:

> This scenario seems to optimize for people who have a GPU but want the model to run on the CPU. That seems rarer than the alternative: someone who wants to fit a model on their GPU that doesn't fit in full precision.

The difference in peak memory between this PR and quantizing on CPU (8.29 GB vs 6.99 GB) is probably due to the largest nn.Linear layer (the LM head). I think what @jerryzh168 means is that we can eliminate this difference by moving each quantized weight back to the CPU one by one, so the peak VRAM during quantization would be that of the largest nn.Linear layer. When we then move the whole CPU-quantized model back to CUDA, the peak memory should be the same as quantizing on CPU (6.99 GB), assuming memory fragmentation doesn't haunt us. But yeah, it's a bit more convoluted.
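
For concreteness, a minimal sketch of that per-param round-trip (not the code in this PR; quantize_weight is a stand-in for whatever quantization the config applies to each weight):

```python
import torch
import torch.nn as nn


def quantize_roundtrip_(model: nn.Module, quantize_weight, device: str = "cuda"):
    # For each linear weight: move it to `device`, quantize there, then move the
    # quantized result back to the weight's original device. Peak GPU memory is
    # bounded by the largest single weight instead of the whole model.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            orig_device = module.weight.device
            w = module.weight.detach().to(device)             # CPU -> CUDA
            qw = quantize_weight(w)                           # quantize on CUDA
            module.weight = nn.Parameter(qw.to(orig_device),  # CUDA -> CPU
                                         requires_grad=False)
    return model
```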

@gau-nernst (Collaborator, Author) commented:

Using Llama2-7B

```
Baseline CPU: Quantize on CPU, then move the whole model to CUDA
  - Time taken: 10.48 s
  - Peak memory: 6.99 GB

Baseline CUDA: Move the whole model to CUDA, then quantize on CUDA
  - Time taken: 1.96 s
  - Peak memory: 14.50 GB

Approach 1: For each param, move to CUDA, quantize on CUDA
  - Time taken: 1.94 s
  - Peak memory: 8.29 GB

Approach 2: For each param, move to CUDA, quantize on CUDA, move back to CPU. Then move the whole model to CUDA
  - Time taken: 5.58 s
  - Peak memory: 7.03 GB
```

There are tradeoffs between approaches 1 and 2. Let me know which one you prefer. I think we should only keep one approach to avoid confusion.

Another thing we could consider is measuring the size of each weight to decide what to do with it, e.g. if the weight exceeds xx size (which hopefully only the LM head does), we always quantize it on the CPU to avoid the memory spike (see the sketch below). Again, this adds seemingly unnecessary complexity.
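
A rough illustration of that size-threshold idea (not implemented in this PR; the 10^8-element cutoff is arbitrary, and it assumes quantize_'s filter_fn argument can be combined with the new device keyword):

```python
import torch.nn as nn

from torchao.quantization.quant_api import int8_weight_only, quantize_

SIZE_THRESHOLD = 100_000_000  # number of elements; arbitrary cutoff for this sketch

def is_small_linear(module, fqn):
    return isinstance(module, nn.Linear) and module.weight.numel() <= SIZE_THRESHOLD

def is_large_linear(module, fqn):
    return isinstance(module, nn.Linear) and module.weight.numel() > SIZE_THRESHOLD

model = get_model()  # same helper as in the test script above

# Small weights: fast per-param CUDA path from this PR.
quantize_(model, int8_weight_only(), filter_fn=is_small_linear, device="cuda")
# Large weights (e.g. the LM head): quantize on CPU to avoid the memory spike.
quantize_(model, int8_weight_only(), filter_fn=is_large_linear)
model.to("cuda")
```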

@msaroufim (Member) commented:

I like approach 1 better because it offers a better tradeoff than the options we already have.

@jerryzh168 (Contributor) left a comment:


Thanks for benchmarking the different approaches @gau-nernst. Yeah, Approach 1 sounds good to me as well; it is simpler and easier to explain in the API.

@msaroufim merged commit 477ddb6 into pytorch:main on Aug 19, 2024 (16 checks passed).
@gau-nernst deleted the quantize_per_param branch on Aug 19, 2024.
Labels: CLA Signed
Successfully merging this pull request may close: Per Layer Streaming Quantization
4 participants