Per Layer Streaming Quantization #655
Comments
This would be great, since currently if I want to run a large model and cannot load it onto my device without running into an OOM, I have to individually push weights to the device, compile their modules, and then quantize them. Though yeah, it would be great if there were a way to iteratively push, compile and then quantize, since otherwise I miss out on possible optimizations.
Quantization on CPU can be significantly faster if we use torch.compile. We can explore compiling the function that performs quantization. Does it work? If it works, does it actually save time, given first-compile overhead as well as potential re-compiles?
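A minimal sketch of what compiling the quantization math could look like, assuming a simple int8 symmetric scheme for illustration (this is not torchao's actual kernel, and the function name and shapes are made up):

```python
import torch

# Illustrative int8 symmetric weight quantization; not torchao's internal code.
@torch.compile
def int8_symmetric_quantize(w: torch.Tensor):
    scale = w.abs().amax(dim=-1, keepdim=True) / 127.0            # per-row absmax scale
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale

w = torch.randn(4096, 4096)
q, scale = int8_symmetric_quantize(w)                              # first call pays the compile cost
q, scale = int8_symmetric_quantize(torch.randn(4096, 4096))        # later calls reuse the compiled graph
```

Whether this actually wins would depend on how the first-compile overhead and any re-compiles (e.g. for new weight shapes) amortize over the number of layers, which is exactly the question above.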
Was looking through Llama model code in torchao and came across [...]:

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_

# build the model on the meta device so no real memory is allocated yet
with torch.device("meta"):
    model = ...

quantize_(model, ...)
model.to_empty(device="cuda")  # materialize quantized model on CUDA

def hook(state_dict, prefix, *args):
    # move original weight to CUDA and quantize it here
    ...

handles = []
for m in model.modules():
    if isinstance(m, nn.Linear):
        handles.append(m._register_load_state_dict_pre_hook(hook))

# weight is quantized on load
model.load_state_dict(state_dict)

# remove hooks
for handle in handles:
    handle.remove()
```
On second thought, it doesn't need to be that complicated. Because we are already iterating over each module when doing quantization, we can simply add a [...] (see ao/torchao/quantization/quant_api.py, lines 171 to 173 at e7fc0ed). A rough sketch of the idea is below.
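A sketch under the assumption that the integration point is the per-module replacement loop referenced above; `with_device_move` and the `device` argument are hypothetical names used only for illustration, not torchao's actual API:

```python
import torch.nn as nn

def with_device_move(replacement_fn, device="cuda"):
    """Hypothetical helper: wrap a per-module quantization/replacement function
    so each module is moved to `device` right before it is quantized."""
    def wrapped(module: nn.Module) -> nn.Module:
        module.to(device)               # only this layer's weights hit the GPU
        return replacement_fn(module)   # quantize/replace it on the GPU
    return wrapped
```

Since the quantization pass visits modules one at a time anyway, this keeps at most one unquantized layer on the GPU at any point.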
This means that the full original model is still in CPU RAM. The convoluted solution in my previous reply can be used with mmap to further reduce RAM usage (i.e. not materializing the full model in CPU RAM).
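For example, a sketch assuming the checkpoint was saved in PyTorch's default zipfile format (the path and `weights_only` flag are illustrative):

```python
import torch

# mmap=True keeps tensors backed by the checkpoint file on disk instead of
# copying everything into CPU RAM up front; each weight is paged in lazily.
state_dict = torch.load("checkpoint.pt", map_location="cpu", mmap=True, weights_only=True)

# `model` here is the meta-device model with the load-time pre-hooks from the
# earlier snippet: each weight is paged in, moved to CUDA, and quantized in turn.
model.load_state_dict(state_dict)
```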
A tradeoff users have often complained about (most recently @aredden) is that they either quantize the model on CPU, which is slow, or load the full unquantized model onto the GPU before quantizing, which can OOM. Instead we could have a utility that sends one layer at a time to the GPU, quantizes it, and then sends in a new layer, synchronously. Granted, this workflow seems to interact in a clunky way with torch.compile, where we don't compile things layer-wise and generally expect the model to be on the device where it's compiled.
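A minimal sketch of what such a utility could look like on top of quantize_; the function name `quantize_layer_by_layer` is made up and `int8_weight_only` is just an example config:

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

def quantize_layer_by_layer(model: nn.Module, device: str = "cuda") -> nn.Module:
    """Hypothetical utility: stream a CPU-resident fp model to the GPU one
    linear layer at a time, quantizing each layer as it lands on device."""
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.to(device)                      # send one layer to the GPU
            quantize_(module, int8_weight_only())  # quantize it there
            # the quantized weight stays on the GPU; peak GPU memory is roughly
            # the quantized model plus one full-precision layer
    return model
```

Compilation would then happen after the loop, once the whole (now quantized) model lives on the device, which sidesteps the layer-wise compile awkwardness mentioned above.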