Use Int8DynActInt4WeightQuantizer in torchao (#2551)
Summary:
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #2551
As titled: switch the 8da4w quantization path in export_llama_lib.py to the Int8DynActInt4WeightQuantizer from torchao instead of the local Int8DynActInt4WeightQuantHandler.

bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks

Pull Request resolved: #2551

Test Plan: python3 -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -qmode 8da4w -X -d fp32

Reviewed By: andrewor14

Differential Revision: D55221981

Pulled By: jerryzh168

fbshipit-source-id: 8f59df461416d5a5bbdc53f3a6baba331ee3f4ce
jerryzh168 authored and facebook-github-bot committed Mar 22, 2024
1 parent 914f642 commit a6aefc0
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions examples/models/llama2/export_llama_lib.py
@@ -38,11 +38,7 @@
 
 from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType
 
-from .quantize import (
-    EmbeddingOnlyInt8QuantHandler,
-    Int8DynActInt4WeightQuantHandler,
-    WeightOnlyInt8QuantHandler,
-)
+from .quantize import EmbeddingOnlyInt8QuantHandler, WeightOnlyInt8QuantHandler
 
 
 IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False)
@@ -241,10 +237,9 @@ def quantize(
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
     elif qmode == "8da4w":
-        model = Int8DynActInt4WeightQuantHandler(
-            model,
-            precision=torch_dtype,
-        ).quantized_model()
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+
+        model = Int8DynActInt4WeightQuantizer(precision=torch_dtype).quantize(model)
         print("quantized model:", model)
         return model
     elif qmode == "8da4w-gptq":
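
For reference, a minimal standalone sketch of the new call pattern. Only the torchao import path and the quantize() call mirror the diff above; the toy two-layer model and the hard-coded fp32 precision are placeholders, not part of this commit.

import torch
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

# Placeholder model; export_llama_lib.py passes the loaded Llama2 model here instead.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.Linear(256, 256),
)

# Int8 dynamic activations + int4 grouped weights; precision stands in for the
# torch_dtype argument used in the diff (fp32 when exporting with "-d fp32").
quantizer = Int8DynActInt4WeightQuantizer(precision=torch.float32)
model = quantizer.quantize(model)
print("quantized model:", model)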
