From a6aefc0048c6eebd6376ab337c5b5ad397a4cebe Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 22 Mar 2024 09:03:58 -0700
Subject: [PATCH] Use `Int8DynActInt4WeightQuantizer` in torchao (#2551)

Summary:
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ https://github.com/pytorch/executorch/issues/2551

att

bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks

Pull Request resolved: https://github.com/pytorch/executorch/pull/2551

Test Plan:
python3 -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -qmode 8da4w -X -d fp32

Reviewed By: andrewor14

Differential Revision: D55221981

Pulled By: jerryzh168

fbshipit-source-id: 8f59df461416d5a5bbdc53f3a6baba331ee3f4ce
---
 examples/models/llama2/export_llama_lib.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 78fc47d95a..4769a80e26 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -38,11 +38,7 @@

 from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType

-from .quantize import (
-    EmbeddingOnlyInt8QuantHandler,
-    Int8DynActInt4WeightQuantHandler,
-    WeightOnlyInt8QuantHandler,
-)
+from .quantize import EmbeddingOnlyInt8QuantHandler, WeightOnlyInt8QuantHandler

 IS_FBCODE = True  # os.environ.get("FBCODE_PLATFORM", False)

@@ -241,10 +237,9 @@ def quantize(
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
     elif qmode == "8da4w":
-        model = Int8DynActInt4WeightQuantHandler(
-            model,
-            precision=torch_dtype,
-        ).quantized_model()
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+
+        model = Int8DynActInt4WeightQuantizer(precision=torch_dtype).quantize(model)
         print("quantized model:", model)
         return model
     elif qmode == "8da4w-gptq":
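
Note (not part of the applied diff): a minimal standalone sketch of the torchao entry point this patch switches to, assuming torchao is installed; the small Sequential module below is a hypothetical stand-in for the Llama model, and only the constructor argument and method shown in the diff (precision= and .quantize()) are relied on.

    import torch
    import torch.nn as nn
    from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

    # Toy module standing in for the Llama model; any nn.Module with Linear layers works.
    model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 32))

    # The quantizer takes the target precision at construction time and returns the
    # model with 8-bit dynamic-activation / 4-bit weight quantization applied.
    model = Int8DynActInt4WeightQuantizer(precision=torch.float32).quantize(model)
    print("quantized model:", model)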