Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor int4 and int8 weight only quantization to use quantize #301

Merged
merged 7 commits into from
Jun 4, 2024

Conversation

jerryzh168
Copy link
Contributor

@jerryzh168 jerryzh168 commented Jun 1, 2024

Summary:
This is similar to #294 but applied for int4 weight only quantization

Test Plan:

unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4_wo_quant_perf
elapsed time: 0.2166275215148926, ref elapsed time: 0.2191881561279297
elapsed time: 0.2376406478881836, ref elapsed time: 0.22721023559570314
elapsed time: 0.21919679641723633, ref elapsed time: 0.2154969596862793

integration perf test:

reference: elapsed_time: 2.5900126953125 milliseconds
after refactor: elapsed_time: 2.56680078125 milliseconds
diff: no diff

TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

Before:
After:
generated code diff:

Reviewers:

Subscribers:

Tasks:

Tags:

Refactor int8 weight only quant to use quantize #299 logs

Summary:
Similar to #294 we replaced the implementation
of int8 weight only quant to used the newly added quantize function, as a part of
the unification effort for affine quantization

Test Plan:

unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_wo_quant_perf
elapsed time: 0.23909856796264647, ref elapsed time: 0.25150911331176756
elapsed time: 0.24894208908081056, ref elapsed time: 0.2570047950744629
elapsed time: 0.21607391357421876, ref elapsed time: 0.22809568405151368

integration test:
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

Reference: elapsed_time: 1.355208740234375 milliseconds
After refactor: elapsed_time: 1.32778857421875 milliseconds

code diff (gist): gist.github.com/jerryzh168/921a722cf20d476c8fc5888482e722dc
code diff (meta-only paste): internalfb.com/phabricator/paste/view/P1387333845

…antize`

Summary:
Previously we added `quantize` as a general API (pytorch#256) for
Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general.

The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant
and 8da4w (for executorch).

This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor
subclass. We'll make sure the performance does not regress for vit model.

Test Plan:
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

reference: elapsed_time:  1.4821058654785155  milliseconds
after refactor: elapsed_time:  1.4804757690429688  milliseconds

generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:
Similar to pytorch#294 we replaced the implementation
of int8 weight only quant to used the newly added `quantize` function, as a part of
the unification effort for affine quantization

Test Plan:
1. unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_wo_quant_perf

elapsed time: 0.23909856796264647, ref elapsed time: 0.25150911331176756
elapsed time: 0.24894208908081056, ref elapsed time: 0.2570047950744629
elapsed time: 0.21607391357421876, ref elapsed time: 0.22809568405151368

2. integration test:

TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

Reference: elapsed_time:  1.355208740234375  milliseconds
After refactor: elapsed_time:  1.32778857421875  milliseconds

code diff (gist): https://gist.github.com/jerryzh168/921a722cf20d476c8fc5888482e722dc
code diff (meta-only paste): https://www.internalfb.com/phabricator/paste/view/P1387333845

Reviewers:

Subscribers:

Tasks:

Tags:
…antize`

Summary:
Previously we added `quantize` as a general API (pytorch#256) for
Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general.

The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant
and 8da4w (for executorch).

This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor
subclass. We'll make sure the performance does not regress for vit model.

Test Plan:
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

reference: elapsed_time:  1.4821058654785155  milliseconds
after refactor: elapsed_time:  1.4804757690429688  milliseconds

generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d

Reviewers:

Subscribers:

Tasks:

Tags:
Copy link

pytorch-bot bot commented Jun 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/301

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 74ecb09 with merge base 729fa4d (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 1, 2024
@jerryzh168 jerryzh168 changed the title Replace implementation for int8 dynamic quantization with call to `qu… Refactor int4 weight only quantization with call to quantize Jun 1, 2024
@jerryzh168
Copy link
Contributor Author

this is rebased on int8-wo PR (#299) so will need to update this PR after the int8-wo PR is landed

@jerryzh168 jerryzh168 changed the title Refactor int4 weight only quantization with call to quantize Refactor int4 weight only quantization to use quantize Jun 1, 2024
@@ -930,6 +930,7 @@ def _test_lin_weight_subclass_impl(
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there an issue or a short description of both bugs we can add, otherwise will be hard to remember when to remove the skipIf

Copy link
Contributor Author

@jerryzh168 jerryzh168 Jun 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's just a inductor c++ compilation bug I think, I'm planning to open a PR after this, I have opened one for the other skip here: #300

test/quantization/test_quant_api.py Outdated Show resolved Hide resolved
test/quantization/test_quant_api.py Outdated Show resolved Hide resolved
return layout_cls
return decorator

def get_aqt_layout_cls(extended_layout: str) -> Callable:
def get_aqt_layout_cls_ctr(extended_layout: str) -> Callable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does ctr stand for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this means constructor, since we are returning class.from_plain now

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a comment I don't believe ctr is a common abbreviation for constructor

# int_data = int_data.view(shape)
# changed = self.from_plain(int_data, scale, zero)
# return changed
# TODO: changing shape is no-op for int4 packed weight right now
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you share some more detail on this I'm quite curious

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I'm confirming with @HDCharles right now, I think this is pretty weird, see comments in L575 of aqt.py for more details


@classmethod
def from_plain(cls, int_data, scale, zero_point):
# TODO: expose the arg
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just do it now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this one needs a bit more discussions with pt core team

if extended_layout == "tensor_core_tiled":
from torchao.quantization.utils import find_multiple
orig_out_features, orig_in_features = input_float.shape
in_features = find_multiple(orig_in_features, 1024)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where do the constants for 1024 and 8 come from?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is specific to tinygemm kernels I think, copied from old code:

in_features = find_multiple(orig_in_features, 1024)
out_features = find_multiple(orig_out_features, 8)

torchao/dtypes/aqt.py Outdated Show resolved Hide resolved
torchao.apply_dynamic_quant(model)
from torch._inductor import config as inductorconfig
inductorconfig.force_fuse_int_mm_with_mul = True
# int8 act, int8 weight dynamic quantization
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we delete code here instead of commenting it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, this is just for people to easily try out different APIs, but we can just ask people to copy paste from README as well

# groupwise int4 quantization
groupsize = weight_qtensor.block_size[-1]
if not _from_flinear:
weight_qtensor = weight_qtensor.t()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n00b q: why does this require a transpose?

Copy link
Contributor Author

@jerryzh168 jerryzh168 Jun 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is to align the dimensions, for block_size so that we can get groupsize from block_size argument, see L662, and also related to L575. right now the _quantized_linear does not have a well-defined accepted weight shape, we need to fix that

@@ -507,7 +571,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)

def _quantized_linear_op(input_tensor, weight_qtensor, bias):
def _quantized_linear_op(input_tensor, weight_qtensor, bias, _from_flinear=True):
# TODO: the old tensor subclass can use the single implementation for both F.linear dispatch
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@msaroufim see this comment for more details

return layout_cls
return decorator

def get_aqt_layout_cls(extended_layout: str) -> Callable:
def get_aqt_layout_cls_ctr(extended_layout: str) -> Callable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a comment I don't believe ctr is a common abbreviation for constructor

filter_fn,
)
if TORCH_VERSION_AFTER_2_4:
quantize(model, get_apply_int4wo_quant(**kwargs), filter_fn)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blind kwargs make it impossible to document the behavior. i understand that change_linear_weights_to_int4_woqtensors has this as well. Seems like something that could be worth fixing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah sure

@@ -55,3 +58,10 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
)
measurement = t0.blocked_autorange()
return measurement.mean * 1e6


def find_multiple(n: int, *args: Tuple[int]) -> int:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we now use this in torchao/dtypes and torchao/quantization and have to do import tricks to avoid circular dep

Summary:
This is similar to pytorch#294 but applied for int4 weight only quantization

Test Plan:

unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4_wo_quant_perf
elapsed time: 0.2166275215148926, ref elapsed time: 0.2191881561279297
elapsed time: 0.2376406478881836, ref elapsed time: 0.22721023559570314
elapsed time: 0.21919679641723633, ref elapsed time: 0.2154969596862793

integration perf test:

reference: elapsed_time:  2.5900126953125  milliseconds
after refactor: elapsed_time:  2.56680078125  milliseconds
diff: no diff

TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

Before:
After:
generated code diff:

Reviewers:

Subscribers:

Tasks:

Tags:
@jerryzh168 jerryzh168 changed the title Refactor int4 weight only quantization to use quantize Refactor int4 and int8 weight only quantization to use quantize Jun 4, 2024
@jerryzh168 jerryzh168 merged commit 338d87c into pytorch:main Jun 4, 2024
12 of 13 checks passed
@jerryzh168 jerryzh168 deleted the int4-wo branch June 4, 2024 17:35
@cpuhrsch
Copy link
Contributor

cpuhrsch commented Jun 4, 2024

Please don't merge PRs when CI is red and we can't get signal for incremental changes. Fix main CI first, then merge.

@jerryzh168
Copy link
Contributor Author

Please don't merge PRs when CI is red and we can't get signal for incremental changes. Fix main CI first, then merge.

makes sense, sorry about this, will do next time

dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
…torch#301)

* Replace implementation for int8 dynamic quantization with call to `quantize`

Summary:
Previously we added `quantize` as a general API (pytorch#256) for
Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general.

The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant
and 8da4w (for executorch).

This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor
subclass. We'll make sure the performance does not regress for vit model.

Test Plan:
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

reference: elapsed_time:  1.4821058654785155  milliseconds
after refactor: elapsed_time:  1.4804757690429688  milliseconds

generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d

Reviewers:

Subscribers:

Tasks:

Tags:

* Refactor int8 weight only quant to use `quantize`

Summary:
Similar to pytorch#294 we replaced the implementation
of int8 weight only quant to used the newly added `quantize` function, as a part of
the unification effort for affine quantization

Test Plan:
1. unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_wo_quant_perf

elapsed time: 0.23909856796264647, ref elapsed time: 0.25150911331176756
elapsed time: 0.24894208908081056, ref elapsed time: 0.2570047950744629
elapsed time: 0.21607391357421876, ref elapsed time: 0.22809568405151368

2. integration test:

TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

Reference: elapsed_time:  1.355208740234375  milliseconds
After refactor: elapsed_time:  1.32778857421875  milliseconds

code diff (gist): https://gist.github.com/jerryzh168/921a722cf20d476c8fc5888482e722dc
code diff (meta-only paste): https://www.internalfb.com/phabricator/paste/view/P1387333845

Reviewers:

Subscribers:

Tasks:

Tags:

* Replace implementation for int8 dynamic quantization with call to `quantize`

Summary:
Previously we added `quantize` as a general API (pytorch#256) for
Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general.

The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant
and 8da4w (for executorch).

This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor
subclass. We'll make sure the performance does not regress for vit model.

Test Plan:
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

reference: elapsed_time:  1.4821058654785155  milliseconds
after refactor: elapsed_time:  1.4804757690429688  milliseconds

generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d

Reviewers:

Subscribers:

Tasks:

Tags:

* Refactor int4 weight only quantization with call to `quantize`

Summary:
This is similar to pytorch#294 but applied for int4 weight only quantization

Test Plan:

unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4_wo_quant_perf
elapsed time: 0.2166275215148926, ref elapsed time: 0.2191881561279297
elapsed time: 0.2376406478881836, ref elapsed time: 0.22721023559570314
elapsed time: 0.21919679641723633, ref elapsed time: 0.2154969596862793

integration perf test:

reference: elapsed_time:  2.5900126953125  milliseconds
after refactor: elapsed_time:  2.56680078125  milliseconds
diff: no diff

TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

Before:
After:
generated code diff:

Reviewers:

Subscribers:

Tasks:

Tags:

---------

Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants