final tests
HDCharles committed Jun 21, 2024
1 parent c499891 commit 98aeee5
Showing 8 changed files with 37 additions and 29 deletions.
test/quantization/test_quant_api.py (2 changes: 0 additions & 2 deletions)

@@ -238,7 +238,6 @@ def test_8da4w_quantizer(self):
     def test_8da4w_gptq_quantizer(self):
         from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
-        torchao._models.llama.model.use_index_put_for_kv_cache = True
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"
@@ -338,7 +337,6 @@ def test_8da4w_quantizer_eval(self):
     def test_gptq_quantizer_int4wo(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
-        torchao._models.llama.model.use_index_put_for_kv_cache = True
         precision = torch.bfloat16
         device = "cuda"
         checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
torchao/_models/_eval.py (27 changes: 16 additions & 11 deletions)

@@ -52,8 +52,13 @@ def __init__(
         pad_token=0,
         device="cpu",
     ):
-        super().__init__()
-        self._tokenizer = tokenizer
+        try:
+            super().__init__()
+        except TypeError:
+            # lm_eval 0.4.2 removed the default init
+            super().__init__("gpt2", device="cpu")
+
+        self.tokenizer = tokenizer
         self._device = torch.device(device)
         self.vocab_size = vocab_size
         self._max_seq_length = calibration_seq_length
@@ -74,9 +79,9 @@ def __init__(
     @property
     def eot_token_id(self):
         try:
-            return self._tokenizer.eos_id()
+            return self.tokenizer.eos_id()
         except:
-            return self._tokenizer.eos_id
+            return self.tokenizer.eos_id

     @property
     def max_length(self):
@@ -96,16 +101,16 @@ def device(self):

     def tok_encode(self, string: str, **kwargs):
         # TODO: verify this for multi-batch as well
-        tokens = self._tokenizer.encode(string)
-        if hasattr(self._tokenizer, "bos_id"):
+        tokens = self.tokenizer.encode(string)
+        if hasattr(self.tokenizer, "bos_id"):
             try:
-                tokens = [self._tokenizer.bos_id()] + tokens
+                tokens = [self.tokenizer.bos_id()] + tokens
             except:
-                tokens = [self._tokenizer.bos_id] + tokens
+                tokens = [self.tokenizer.bos_id] + tokens
         return tokens

     def tok_decode(self, tokens):
-        decoded = self._tokenizer.decode(tokens)
+        decoded = self.tokenizer.decode(tokens)
         return decoded

     def add_input(self, args):
@@ -185,9 +190,9 @@ def __init__(
         input_prep_func=None,
         device="cuda"
     ):
-        super().__init__(None, None)
+        super().__init__(tokenizer, None)
         self._model = model
-        self._tokenizer = tokenizer
+        # self.tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = max_seq_length
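For context, the try/except added to InputRecorder above handles the lm_eval 0.4.2 change noted in the new comment, where the base eval wrapper's __init__ can no longer be called without arguments, and the tokenizer is now exposed as .tokenizer rather than ._tokenizer. A minimal sketch of the same pattern, using a stand-in base class rather than the real lm_eval one:

```python
class BaseEvalWrapper:
    # stand-in for the lm_eval base class; newer releases require a model argument
    def __init__(self, pretrained, device="cpu"):
        self.pretrained = pretrained
        self.device = device


class Recorder(BaseEvalWrapper):
    def __init__(self, tokenizer, device="cpu"):
        try:
            # older lm_eval releases allowed a bare super().__init__()
            super().__init__()
        except TypeError:
            # newer releases need a model name, so pass a harmless placeholder
            super().__init__("gpt2", device="cpu")
        # exposed as .tokenizer (not ._tokenizer) so harness code can find it
        self.tokenizer = tokenizer


recorder = Recorder(tokenizer="dummy-tokenizer")
```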
torchao/_models/llama/eval.py (2 changes: 1 addition & 1 deletion)

@@ -21,7 +21,7 @@
 from tokenizer import get_tokenizer
 import time
 from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
-from model import prepare_inputs_for_model
+from torchao._models.llama.model import prepare_inputs_for_model

 torch._inductor.config.fx_graph_cache = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True
torchao/_models/llama/generate.py (4 changes: 2 additions & 2 deletions)

@@ -36,8 +36,8 @@ def device_sync(device):
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))

-from model import Transformer, prepare_inputs_for_model
-from tokenizer import get_tokenizer
+from torchao._models.llama.model import Transformer, prepare_inputs_for_model
+from torchao._models.llama.tokenizer import get_tokenizer

 def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
     q = torch.empty_like(probs_sort).exponential_(1)
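The import changes in eval.py and generate.py above swap script-relative imports for package-absolute ones, so these helpers resolve whether the files are run as scripts or imported through the torchao package. Illustration only, the second form is the old behaviour:

```python
# resolves anywhere the torchao package is importable
from torchao._models.llama.model import prepare_inputs_for_model

# the old form only resolved when torchao/_models/llama/ itself was on sys.path,
# e.g. when running generate.py from inside that directory
# from model import prepare_inputs_for_model
```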
torchao/_models/llama/tokenizer.py (4 changes: 4 additions & 0 deletions)

@@ -30,6 +30,8 @@ class SentencePieceWrapper(TokenizerInterface):
     def __init__(self, model_path):
         super().__init__(model_path)
         self.processor = spm.SentencePieceProcessor(str(model_path))
+        self.bos_token_id = self.bos_id()
+        self.eos_token_id = self.eos_id()

     def encode(self, text):
         return self.processor.EncodeAsIds(text)
@@ -86,6 +88,8 @@ def __init__(self, model_path):
         # BOS / EOS token IDs
         self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
         self._eos_id: int = self.special_tokens["<|end_of_text|>"]
+        self.bos_token_id = self.bos_id()
+        self.eos_token_id = self.eos_id()

     def encode(self, text):
         return self.model.encode(text)
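The new bos_token_id / eos_token_id attributes added above mirror the attribute-style interface that evaluation harnesses commonly read, alongside the existing bos_id() / eos_id() methods. A toy sketch of the dual interface and the kind of fallback the eval wrapper uses (the classes here are illustrative, not the torchao tokenizers):

```python
class ToyTokenizer:
    def __init__(self, bos=1, eos=2):
        self._bos, self._eos = bos, eos
        # attribute form, for callers written against HF-style tokenizers
        self.bos_token_id = self.bos_id()
        self.eos_token_id = self.eos_id()

    def bos_id(self):
        return self._bos

    def eos_id(self):
        return self._eos


def eot_token_id(tok):
    # same shape as InputRecorder.eot_token_id: prefer the callable, fall back to the attribute
    try:
        return tok.eos_id()
    except (AttributeError, TypeError):
        return tok.eos_token_id


assert eot_token_id(ToyTokenizer()) == 2
```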
torchao/quantization/GPTQ.py (9 changes: 3 additions & 6 deletions)

@@ -39,11 +39,6 @@
 )
 aten = torch.ops.aten

-# need this to fix the model so it works for GPTQ
-import torchao
-from torchao._models.llama.model import use_index_put_for_kv_cache
-torchao._models.llama.model.use_index_put_for_kv_cache = True
-
 if not _lm_eval_available:
     logging.info("lm_eval is not installed, GPTQ may not be usable")

@@ -86,7 +81,9 @@ def __init__(

         # trace model for one input
         one_input = [multi.values[0].cpu() for multi in inputs]  # pyre-ignore[16]
-        model.cpu()(*one_input)
+        # needed for GPTQ on the torchao llama model
+        import torchao
+        torchao._models.llama.model.use_index_put_for_kv_cache = True
         exported_model = torch._dynamo.export(
             model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
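Setting use_index_put_for_kv_cache right before the torch._dynamo.export call, instead of at import time, keeps the side effect scoped to GPTQ tracing. The flag switches the llama KV-cache update to an aten.index_put_ form that export can trace; roughly the pattern looks like this (a sketch of the idea, not the actual torchao model code):

```python
import torch

use_index_put_for_kv_cache = False  # module-level switch, flipped on before export


def kv_cache_update(k_cache, v_cache, input_pos, k_val, v_val):
    if use_index_put_for_kv_cache:
        # traceable functional-style update used during GPTQ export
        k_out = torch.ops.aten.index_put_(k_cache, [None, None, input_pos], k_val)
        v_out = torch.ops.aten.index_put_(v_cache, [None, None, input_pos], v_val)
    else:
        # eager-mode in-place slice assignment
        k_out, v_out = k_cache, v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val
    return k_out, v_out
```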
torchao/quantization/__init__.py (2 changes: 0 additions & 2 deletions)

@@ -17,8 +17,6 @@
     "swap_conv2d_1x1_to_linear"
     "safe_int_mm",
     "autoquant",
-    "change_linears_to_autoquantizable",
-    "change_autoquantizable_to_quantized",
     "get_scale",
     "SmoothFakeDynQuantMixin",
     "SmoothFakeDynamicallyQuantizedLinear",
torchao/quantization/autoquant.py (16 changes: 11 additions & 5 deletions)

@@ -17,6 +17,12 @@
 except:
     from torch._inductor.runtime.runtime_utils import do_bench

+__all__ = [
+    "AutoQuantizableLinearWeight",
+    "autoquant",
+]
+
+
 aten = torch.ops.aten

 AUTOQUANT_CACHE = {}
@@ -382,11 +388,11 @@ def from_float(cls, weight):
     # TODO this gets picked in places where it makes perf worse, why?
 ]

-def change_linears_to_autoquantizable(model, **kwargs):
+def _change_linears_to_autoquantizable(model, **kwargs):
     """
     Converts all linear weight tensors to the
     AutoQuantizableLinearWeight tensor subclass. Expectation is that this is followed
-    by running the model and then calling change_autoquantizable_to_quantized
+    by running the model and then calling _change_autoquantizable_to_quantized
     """
     from torchao.quantization.quant_api import _is_linear
     filter_fn = kwargs.pop("filter_fn", _is_linear)
@@ -401,7 +407,7 @@ def change_linears_to_autoquantizable(model, **kwargs):
         filter_fn if filter_fn is not None else _is_linear,
     )

-def change_autoquantizable_to_quantized(model, **kwargs):
+def _change_autoquantizable_to_quantized(model, **kwargs):
     """
     Converts AutoQuantizableLinearWeight tensor subclasses
     to various quantized/non-quantized tensor subclasses depending
@@ -461,7 +467,7 @@ def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST,
     # autoquantization
     def autoquant_prehook(module, args, kwargs):
         module.forward_log_only(*args, **kwargs)
-        change_autoquantizable_to_quantized(
+        _change_autoquantizable_to_quantized(
             module,
             **aq_kwargs,
         )
@@ -470,7 +476,7 @@ def autoquant_prehook(module, args, kwargs):

     # perform initial swap from linear weights
     # to AutoQuantizableLinearWeight
-    change_linears_to_autoquantizable(
+    _change_linears_to_autoquantizable(
         model,
         filter_fn=filter_fn,
         qtensor_class_list=qtensor_class_list,
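With change_linears_to_autoquantizable and change_autoquantizable_to_quantized renamed to underscore-prefixed helpers and dropped from __init__, autoquant remains the public entry point. A rough usage sketch, assuming a CUDA-capable setup (the toy model and shapes are placeholders):

```python
import torch
import torchao

# toy model; autoquant wraps each Linear weight in AutoQuantizableLinearWeight
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to(device="cuda", dtype=torch.bfloat16)
example_input = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)

model = torchao.autoquant(torch.compile(model))
# the first forward pass benchmarks the candidate quantized kernels per layer
# and swaps in the fastest one before running
model(example_input)
```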
