
Per Layer Streaming Quantization #655

Closed
msaroufim opened this issue Aug 12, 2024 · 4 comments · Fixed by #699
Labels
good first issue Good for newcomers

Comments

@msaroufim
Member

msaroufim commented Aug 12, 2024

A tradeoff users have often complained about, most recently @aredden, is that they must either:

  1. quantize on CPU and then push the model to GPU -> slow quantization but VRAM efficient
  2. push the model to GPU and then quantize on GPU -> fast quantization but needs lots of VRAM

Instead we could have a utility that sends one layer at a time to the GPU, quantizes it, and then sends over the next layer synchronously (a sketch follows below). Granted, this workflow seems to interact in a clunky way with torch.compile, where we don't compile things layer-wise and generally expect the model to already be on the device where it's compiled.
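
A minimal sketch of such a utility, assuming torchao's quantize_ API with int8_weight_only as an example config (the helper name quantize_streaming_ is hypothetical, not an existing API):

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

def quantize_streaming_(model: nn.Module, apply_tensor_subclass, device="cuda"):
    # Stream one linear layer at a time: move it to the GPU, quantize it there,
    # and keep the (now much smaller) quantized weight on the device.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.to(device)
            quantize_(module, apply_tensor_subclass)
    # remaining (non-quantized) modules are comparatively small; move them at the end
    model.to(device)
    return model

# Usage: the full model stays in CPU RAM; at most one unquantized layer
# occupies VRAM at any time.
# model = ...  # load on CPU
# quantize_streaming_(model, int8_weight_only())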

msaroufim added the good first issue label on Aug 12, 2024
@aredden

aredden commented Aug 12, 2024

This would be great. Currently, if I want to run a large model that I cannot load onto my device without hitting an OOM, I have to individually push weights to the device, compile their modules, and then quantize them. Having a built-in way to iteratively push, compile, and then quantize would be ideal, since otherwise I miss out on possible optimizations.

@gau-nernst
Collaborator

Quantization on CPU can be significantly faster if we use torch.compile(). However, from previous discussions, we don't want to do that because it makes things harder to debug (#315 (comment)).

We can explore compiling just the function that performs quantization. Does it work? And if it does, does it actually save time once we account for the first-compile overhead and potential re-compiles?
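
One way to sanity-check this, with an illustrative symmetric int8 weight-quantization function (not torchao's actual quantization primitive):

import torch

def quantize_int8_symmetric(w: torch.Tensor):
    # per-output-channel symmetric int8 quantization (illustrative only)
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale

compiled_quantize = torch.compile(quantize_int8_symmetric)

w = torch.randn(4096, 4096)
q, scale = compiled_quantize(w)  # first call pays the compile overhead
q, scale = compiled_quantize(w)  # later calls with the same shape/dtype avoid re-compiles

Timing the compiled vs. eager version across a realistic set of weight shapes would answer whether the compile overhead is amortized.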

@gau-nernst
Collaborator

gau-nernst commented Aug 13, 2024

Was looking through the Llama model code in torchao and came across nn.Module._register_load_state_dict_pre_hook(). I remember having some ideas about using this for memory-efficient quantization too. Something like (haven't tested):

import torch
import torch.nn as nn
from torchao.quantization import quantize_

with torch.device("meta"):
    model = ...
    quantize_(model, ...)
model.to_empty(device="cuda")  # materialize quantized model on CUDA

def hook(state_dict, prefix, *args):
    # called as (state_dict, prefix, local_metadata, strict,
    #            missing_keys, unexpected_keys, error_msgs)
    # move original weight to CUDA and quantize it here
    ...

handles = []
for m in model.modules():
    if isinstance(m, nn.Linear):
        handles.append(m._register_load_state_dict_pre_hook(hook))

# weight is quantized on load
model.load_state_dict(state_dict)

# remove hooks
for handle in handles:
    handle.remove()

@gau-nernst
Collaborator

On second thought, it doesn't need to be that complicated. Because we are already iterating over each module when doing quantization, we can simply add a device argument and move the matching module (e.g. nn.Linear) to the desired device (e.g. CUDA) before applying the quantization function:

if filter_fn(model, cur_fqn[:-1]):
    model = replacement_fn(model)
    return model

This means the full original model still sits in CPU RAM; a sketch of the change is below. The more convoluted solution in my previous reply could be combined with mmap to reduce RAM usage further (i.e. avoid materializing the full model in CPU RAM).
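
A minimal sketch of this change, modeled on torchao's _replace_with_custom_fn_if_matches_filter (the device argument is the proposed addition; the actual fix in #699 may differ in its details):

def _replace_with_custom_fn_if_matches_filter(
    model, replacement_fn, filter_fn, cur_fqn="", device=None
):
    if filter_fn(model, cur_fqn[:-1]):
        if device is not None:
            model.to(device=device)  # move just this module before quantizing it
        return replacement_fn(model)
    for name, child in model.named_children():
        new_child = _replace_with_custom_fn_if_matches_filter(
            child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device
        )
        if new_child is not child:
            setattr(model, name, new_child)
    return model

Each matching module is moved to the target device only when it is about to be quantized, so at most one unquantized layer occupies GPU memory at a time.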
