Add codebook (look up table based) quantization flow in torchao #1195

Open
1 of 3 tasks
jerryzh168 opened this issue Oct 29, 2024 · 18 comments
Labels: good first issue (Good for newcomers)

Comments

@jerryzh168
Contributor

jerryzh168 commented Oct 29, 2024

Similar to affine quantization, we can implement codebook or lookup-table-based quantization, which is another popular type of quantization, especially for lower bits like 4 bits and below (used in https://github.com/Vahe1994/AQLM, https://arxiv.org/abs/2402.04396, etc.). We can start with post-training quantization and use k-means clustering to find the codebook / lookup table. You can check out #391 for the overall structure of the torchao stack. Reference code for k-means can be found here.

After this we can also add more support for the advanced algorithms mentioned above.

API

quantize_(model, codebook_weight_only(dtype=torch.uint4))
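
To make the proposed flow concrete, here is a minimal usage sketch. The import path for `codebook_weight_only` is an assumption based on the code location suggested below; the config name and dtype argument come from the API line above.

```python
import torch
import torch.nn as nn

from torchao.quantization import quantize_
# Assumed import path, based on the proposed code location under
# torchao/prototype/quantization; it may differ once the PRs land.
from torchao.prototype.quantization.codebook import codebook_weight_only

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 64))
x = torch.randn(8, 256)
ref = model(x)

# Weight-only codebook quantization: each linear weight is replaced by uint4
# codes plus a k-means-derived lookup table (CodebookQuantizedTensor).
quantize_(model, codebook_weight_only(dtype=torch.uint4))

out = model(x)
print((out - ref).abs().max())  # rough sanity check against the fp32 output
```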

Implementation details:

  • [PR1] Ops
    • quantize_codebook(tensor, codebook)
    • dequantize_codebook(tensor, codebook)
  • [PR2] Tensor Subclass
    • CodebookQuantizedTensor (similar to AffineQuantizedTensor)
      • clustering algorithm can be implemented in from_float function

We still need to flesh out the details of the args etc., but that can be done in the PR. I'd suggest gradually adding things and gathering feedback.
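
To sketch what the two ops and the k-means step might look like, here is a rough, self-contained reference. This is illustrative only: the function signatures, the 1x1 "group size", and the uint8 storage are assumptions rather than the final design, and uint4 packing is omitted.

```python
import torch

def build_codebook(weight: torch.Tensor, num_centroids: int = 16, iters: int = 20) -> torch.Tensor:
    """Plain k-means over the flattened weight values (1x1 'group size')."""
    values = weight.reshape(-1, 1)
    # Initialize centroids from random samples of the weight.
    idx = torch.randperm(values.numel(), device=values.device)[:num_centroids]
    codebook = values[idx].clone()
    for _ in range(iters):
        # Assign every value to its nearest centroid.
        assignments = torch.cdist(values, codebook).argmin(dim=1)
        # Recompute each centroid as the mean of its assigned values.
        for c in range(num_centroids):
            mask = assignments == c
            if mask.any():
                codebook[c] = values[mask].mean()
    return codebook.reshape(-1)  # shape: (num_centroids,)

def quantize_codebook(weight: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Map each weight value to the index of its nearest codebook entry."""
    flat = weight.reshape(-1, 1)
    codes = torch.cdist(flat, codebook.reshape(-1, 1)).argmin(dim=1)
    return codes.reshape(weight.shape).to(torch.uint8)  # uint4 packing omitted

def dequantize_codebook(codes: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Look the codes back up in the table."""
    return codebook[codes.to(torch.long)]

w = torch.randn(256, 256)
cb = build_codebook(w, num_centroids=16)  # 16 entries roughly corresponds to uint4
w_hat = dequantize_codebook(quantize_codebook(w, cb), cb)
print((w - w_hat).abs().mean())
```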

Code Location: add a codebook folder under https://github.com/pytorch/ao/tree/main/torchao/prototype/quantization

@malinjawi

Hey @jerryzh168 I am new to torchao but this sounds like an issue I would want to investigate with my partner @Harthi7. We will take a look and let you know how it goes. Cheers!

@DerekLiu35
Contributor

Hi, I am also new to torchao and I would like to work on this issue.

@pawarmanasi07

pawarmanasi07 commented Dec 30, 2024

Hi! I'm interested in contributing to the implementation of the codebook quantization. Would it be helpful if I worked on something like adding test cases? Happy to coordinate with @DerekLiu35 to avoid duplicating effort.

@DerekLiu35
Contributor

DerekLiu35 commented Dec 30, 2024

I'd also be happy to coordinate.
I think the main things to do are adding AQLM support (the tuning part, though I'm not sure why it would be beneficial to have the tuning in torchao compared to just using the AQLM repo and then converting to the torchao representation) and making token generation faster (probably by copying dequantization kernels from AQLM).

@jerryzh168
Contributor Author

I think the two immediate things are adding AQLM support and speeding things up. Adding AQLM in torchao will be a bit more convenient for users than using the AQLM repo and then converting, I think.

@pawarmanasi07

Great! Let me know what I can start with.

@DerekLiu35
Contributor

I'll focus on speeding up token generation; we can coordinate more if @pawarmanasi07 also wants to work on that.

@pawarmanasi07

I can help with that!

@pawarmanasi07

@DerekLiu35 Could you share your thoughts on which aspects of the dequantization kernels from AQLM we should focus on first? We could divide different parts of the optimization work between us.

@DerekLiu35
Contributor

I think we can focus on 1x16 group size CUDA kernels and Triton (as a fallback). We could divide the optimization work by one of us focusing on forward pass kernels and the other on backward pass kernels, though I'm not sure why you would need backward pass kernels. We could also split by different kernels, like 1x16 group size and 1x1 group size (no reference CUDA kernels in AQLM). I'm not sure what the best way to divide the work between us is, though. I'll probably start with the 1x16 forward pass kernel.
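
For anyone following along, here is a plain-PyTorch sketch of the semantics of 1x16-group dequantization being discussed. It is illustrative only: it ignores AQLM's multi-codebook summation and packing, and the shapes and names below are assumptions rather than AQLM's actual layout.

```python
import torch

def dequantize_1x16_reference(codes: torch.Tensor,
                              codebook: torch.Tensor,
                              scales: torch.Tensor) -> torch.Tensor:
    """
    codes:    (out_features, in_features // 16) integer indices
    codebook: (num_centroids, 16) lookup table of 16-element vectors
    scales:   (out_features, 1) per-output-channel scales
    returns:  (out_features, in_features) dequantized weight
    """
    out_features, num_groups = codes.shape
    # Gather one 16-element vector per code, then stitch the groups back together.
    groups = codebook[codes.to(torch.long)]            # (out, groups, 16)
    weight = groups.reshape(out_features, num_groups * 16)
    return weight * scales

# Toy shapes; a fused CUDA/Triton kernel would do this gather inside the matmul.
codes = torch.randint(0, 256, (64, 8))
codebook = torch.randn(256, 16)
scales = torch.ones(64, 1)
w = dequantize_1x16_reference(codes, codebook, scales)
print(w.shape)  # torch.Size([64, 128])
```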

@pawarmanasi07

Sounds good! I think focusing on the 1x16 group size kernels makes sense as a starting point. I can work on the 1x1 group size kernels while you tackle the 1x16 forward pass implementation.

For the backward pass kernels - you raise a good point about whether they're necessary. Since this is post-training quantization, we likely don't need backward pass optimization unless we're planning to support fine-tuning scenarios?

@pawarmanasi07

pawarmanasi07 commented Jan 1, 2025

Hi @DerekLiu35 and @jerryzh168, to confirm my tasks - I'll be focusing on optimizing the dequantization for 1x1 group size.
This involves:

  • Creating an optimized kernel for 1x1 dequantization

  • Adding benchmarks and tests

  • Integrating with the existing codebase

Meanwhile, Derek will focus on the 1x16 forward pass kernel implementation.
Since there are no reference CUDA kernels in AQLM for 1x1, should I:

  • Implement new CUDA kernels for 1x1,
  • Use Triton for 1x1, or
  • Implement both approaches?

Is this the correct understanding of the work division? I just want to ensure I'm heading in the right direction before starting.

@pawarmanasi07

However, would it make more sense to start with a Triton implementation for 1x1 first (since we need it as a fallback anyway), then evaluate whether we need a CUDA implementation based on performance?

@DerekLiu35
Contributor

DerekLiu35 commented Jan 1, 2025

Yeah, I think it would make sense to start with the Triton fallback first.

@LucasHaug

LucasHaug commented Jan 10, 2025

Not sure if this is the right place to post this, but I was testing the CodebookQuantizedTensor and noticed a problem. I can't say if it's easy to fix or if it's a priority, but I thought it would be better to report it.

I was trying to set the requires_grad of the tensor to True, but as the codes are uint, they can't require gradients.

For example:

import torch
from torchao.prototype.quantization.codebook.codebook_quantized_tensor import CodebookQuantizedTensor

input_tensor = torch.randn(1024, 1024,  device='cuda')

block_size = (1, 1)
code_dtype = torch.uint4

quantized_tensor = CodebookQuantizedTensor.from_float(input_tensor, block_size, code_dtype, scale_block_size=32)

quantized_tensor.requires_grad_(True)

Results in:

Traceback (most recent call last):
  File "/home/haug/test.py", line 11, in <module>
    quantized_tensor.requires_grad_(True)
  File "/home/haug/ao/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py", line 225, in requires_grad_
    self.codes.requires_grad_(requires_grad)
  File "/home/haug/ao/torchao/utils.py", line 436, in _dispatch__torch_function__
    return func(*args, **kwargs)
RuntimeError: only Tensors of floating point dtype can require gradients

@jerryzh168
Contributor Author

jerryzh168 commented Jan 10, 2025

Thanks for trying this out @LucasHaug. Yeah, codebook quant is inference-only right now; we should definitely support training as well, and it will be supported when we add AQLM, I think.

@DerekLiu35 @pawarmanasi07 any of you also plans to work on supporting AQLM flow?

For AQLM I think we can use the module swap API for now; see `convert_to_float8_training` for an example. Please let me know if you need more detailed guidelines.

A concrete official training guideline is still in discussion. cc @vkuzo @andrewor14
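
To make the module-swap suggestion a bit more concrete, here is a rough sketch of what such a conversion pass could look like. Everything below is hypothetical: `CodebookLinear` and `convert_to_codebook_training` are not torchao APIs, and a real version would store trainable codebooks and codes rather than a dequantized weight.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CodebookLinear(nn.Module):
    """Hypothetical stand-in for an AQLM-style quantized linear layer."""
    def __init__(self, linear: nn.Linear):
        super().__init__()
        # A real implementation would hold codes plus a (trainable) codebook here;
        # for this sketch we just keep a dequantized copy of the weight.
        self.register_buffer("weight_dequant", linear.weight.detach().clone())
        self.bias = linear.bias

    def forward(self, x):
        return F.linear(x, self.weight_dequant, self.bias)

def convert_to_codebook_training(model: nn.Module) -> nn.Module:
    """Module-swap pass in the spirit of convert_to_float8_training:
    recursively replace every nn.Linear with a CodebookLinear."""
    for name, child in list(model.named_children()):
        if isinstance(child, nn.Linear):
            setattr(model, name, CodebookLinear(child))
        else:
            convert_to_codebook_training(child)
    return model

m = nn.Sequential(nn.Linear(32, 32), nn.Sequential(nn.Linear(32, 8)))
convert_to_codebook_training(m)
print(m)  # both Linear layers are now CodebookLinear
```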

@pawarmanasi07

Yeah, I plan on working on it after I finish the current task.

@mostafaelhoushi

I have been in discussion with @jcaip about integrating a 1x1 codebook quantization algorithm (for a paper that should be on arXiv soon), as well as a kernel, into torchao, and just noticed the discussions in this issue.
