eval script fixes #414

Merged: 4 commits, Jun 21, 2024

Changes from all commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
test/quantization/test_quant_api.py (2 changes: 0 additions & 2 deletions)

@@ -238,7 +238,6 @@ def test_8da4w_quantizer(self):
     def test_8da4w_gptq_quantizer(self):
         from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
-        torchao._models.llama.model.use_index_put_for_kv_cache = True
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"

@@ -338,7 +337,6 @@ def test_8da4w_quantizer_eval(self):
     def test_gptq_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
-        torchao._models.llama.model.use_index_put_for_kv_cache = True
         precision = torch.bfloat16
         device = "cuda"
         checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
torchao/_models/_eval.py (27 changes: 16 additions & 11 deletions)

@@ -52,8 +52,13 @@ def __init__(
         pad_token=0,
         device="cpu",
     ):
-        super().__init__()
-        self._tokenizer = tokenizer
+        try:
+            super().__init__()
+        except TypeError:
+            # lm_eval 0.4.2 removed the default init
+            super().__init__("gpt2", device="cpu")
+
+        self.tokenizer = tokenizer
         self._device = torch.device(device)
         self.vocab_size = vocab_size
         self._max_seq_length = calibration_seq_length

@@ -74,9 +79,9 @@ def __init__(
     @property
     def eot_token_id(self):
         try:
-            return self._tokenizer.eos_id()
+            return self.tokenizer.eos_id()
         except:
-            return self._tokenizer.eos_id
+            return self.tokenizer.eos_id

     @property
     def max_length(self):

@@ -96,16 +101,16 @@ def device(self):
     def tok_encode(self, string: str, **kwargs):
         # TODO: verify this for multi-batch as well
-        tokens = self._tokenizer.encode(string)
-        if hasattr(self._tokenizer, "bos_id"):
+        tokens = self.tokenizer.encode(string)
+        if hasattr(self.tokenizer, "bos_id"):
             try:
-                tokens = [self._tokenizer.bos_id()] + tokens
+                tokens = [self.tokenizer.bos_id()] + tokens
             except:
-                tokens = [self._tokenizer.bos_id] + tokens
+                tokens = [self.tokenizer.bos_id] + tokens
         return tokens

     def tok_decode(self, tokens):
-        decoded = self._tokenizer.decode(tokens)
+        decoded = self.tokenizer.decode(tokens)
         return decoded

     def add_input(self, args):

@@ -185,9 +190,9 @@ def __init__(
         input_prep_func=None,
         device="cuda"
     ):
-        super().__init__(None, None)
+        super().__init__(tokenizer, None)
         self._model = model
-        self._tokenizer = tokenizer
+        # self.tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = max_seq_length
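The try/except around super().__init__() above keeps InputRecorder working against lm_eval versions with and without a default base initializer, and the tokenizer is now stored as self.tokenizer, the name the rest of this file reads. A standalone sketch of the same fallback pattern, built on a hypothetical Base class rather than the real lm_eval base (an illustration only, not the lm_eval API):

    # Hypothetical Base stands in for a wrapper whose newer versions require a
    # positional model name in __init__; older versions accepted no arguments.
    class Base:
        def __init__(self, model_name=None, device="cpu"):
            if model_name is None:
                raise TypeError("model_name is required in this version")
            self.model_name = model_name
            self.device = device

    class Recorder(Base):
        def __init__(self, tokenizer, device="cpu"):
            try:
                super().__init__()                       # older API: no-arg init works
            except TypeError:
                super().__init__("gpt2", device=device)  # newer API: pass a placeholder name
            self.tokenizer = tokenizer                   # plain attribute, as in the diff

    r = Recorder(tokenizer=None)
    assert r.model_name == "gpt2"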
torchao/_models/llama/eval.py (13 changes: 7 additions & 6 deletions)

@@ -13,15 +13,15 @@
 )
 from torchao.quantization.quant_api import (
-    quantize, int4wo, int8wo, int8da_int8w, unwrap_tensor_subclass
+    quantize, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass
 )
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder

 from tokenizer import get_tokenizer
 import time
 from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
-from model import prepare_inputs_for_model
+from torchao._models.llama.model import prepare_inputs_for_model

 torch._inductor.config.fx_graph_cache = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True

@@ -60,17 +60,18 @@ def run_evaluation(
     if quantization:
         if "int8wo" in quantization:
-            quantize(model, int8wo())
+            quantize(model, int8_weight_only())
[Inline review comment, Contributor]
Does this need to be compatible with torch 2.3 and below? If so we could define similar helpers:

    def _int8wo_api(mod):
        if TORCH_VERSION_AFTER_2_4:
            quantize(mod, int8_weight_only())
            unwrap_tensor_subclass(mod)
        else:
            change_linear_weights_to_int8_woqtensors(mod)

    def _int8da_int8w_api(mod):
        if TORCH_VERSION_AFTER_2_4:
            quantize(mod, int8_dynamic_activation_int8_weight())
            unwrap_tensor_subclass(mod)
        else:
            change_linear_weights_to_int8_dqtensors(mod)

    def _int4wo_api(mod):
        if TORCH_VERSION_AFTER_2_4:
            quantize(mod, int4_weight_only())
            unwrap_tensor_subclass(mod)
        else:
            change_linear_weights_to_int4_woqtensors(mod)

[Reply, HDCharles (author), Jun 21, 2024]
I think it's mostly for our own testing; not sure that's needed.

if "int8dq" in quantization:
quantize(model, int8da_int8w())
quantize(model, int8_dynamic_activation_int8_weight())
if "int4wo" in quantization and not "gptq" in quantization:
groupsize=int(quantization.split("-")[-1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize(model, int4wo(groupsize=groupsize))
quantize(model.to(device), int4_weight_only(group_size=groupsize))
if "int4wo" in quantization and "gptq" in quantization:
groupsize=int(quantization.split("-")[-2])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
assert precision==torch.bfloat16, f"{quantization} requires precision or bfloat16 but got {precision}"
assert "cuda" in device, "int4 gptq quantization only works on cuda"
inputs = InputRecorder(
tokenizer,
calibration_seq_length,
Expand All @@ -83,7 +84,7 @@ def run_evaluation(
calibration_limit,
).get_inputs()

quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize)
quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize, device=device)
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
model = quantizer.quantize(model, inputs).to(device)
else:
Expand Down
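One detail worth spelling out from the branches above: the group size is encoded in the quantization string itself, last for "int4wo-<n>" and second to last for "int4wo-<n>-gptq". A small self-contained sketch of that convention (illustration only, not code from the PR):

    def parse_groupsize(quantization: str) -> int:
        # "-gptq" variants carry the group size second to last; plain int4wo carries it last
        idx = -2 if "gptq" in quantization else -1
        groupsize = int(quantization.split("-")[idx])
        assert groupsize in [32, 64, 128, 256], f"unsupported group size: {groupsize}"
        return groupsize

    assert parse_groupsize("int4wo-64") == 64
    assert parse_groupsize("int4wo-128-gptq") == 128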
torchao/_models/llama/generate.py (16 changes: 8 additions & 8 deletions)

@@ -35,8 +35,8 @@ def device_sync(device):
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))

-from model import Transformer, prepare_inputs_for_model
-from tokenizer import get_tokenizer
+from torchao._models.llama.model import Transformer, prepare_inputs_for_model
+from torchao._models.llama.tokenizer import get_tokenizer

 def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
     q = torch.empty_like(probs_sort).exponential_(1)

@@ -189,21 +189,21 @@ def main(
     if quantization:
         from torchao.quantization.quant_api import (
[Inline review comment, Contributor]
Can we dedup the quant code in eval.py and generate.py?

[Reply, HDCharles (author), Jun 21, 2024]
Only a bit; it's probably more trouble than it's worth given the differences and the need to handle autoquant vs gptq, etc.

             quantize,
-            int8wo,
-            int8da_int8w,
-            int4wo,
+            int8_weight_only,
+            int8_dynamic_activation_int8_weight,
+            int4_weight_only,
             autoquant,
             unwrap_tensor_subclass
         )

         if "int8wo" in quantization:
-            quantize(model, int8wo())
+            quantize(model, int8_weight_only())
         if "int8dq" in quantization:
-            quantize(model, int8da_int8w())
+            quantize(model, int8_dynamic_activation_int8_weight())
         if "int4wo" in quantization:
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-            quantize(model, int4wo(groupsize=groupsize))
+            quantize(model, int4_weight_only(groupsize=groupsize))
[Inline review comment, Contributor]
This is group_size since the last update, I think.

[Reply, HDCharles (author)]
I'll fix it in another PR.

if "autoquant" == quantization:
model = autoquant(model, manual=True)

Expand Down
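To make the keyword mismatch flagged in the thread above concrete: eval.py now passes group_size= while generate.py still passes groupsize=, and only one spelling can match the actual parameter name. A self-contained sketch with a hypothetical signature (the real int4_weight_only signature is not part of this diff):

    def int4_weight_only_sketch(group_size=128):
        # stand-in for a config helper whose keyword argument is named group_size
        return {"kind": "int4 weight-only", "group_size": group_size}

    print(int4_weight_only_sketch(group_size=64))   # new keyword: accepted
    try:
        int4_weight_only_sketch(groupsize=64)       # old keyword: rejected
    except TypeError as err:
        print(f"TypeError: {err}")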
torchao/_models/llama/tokenizer.py (4 changes: 4 additions & 0 deletions)

@@ -30,6 +30,8 @@ class SentencePieceWrapper(TokenizerInterface):
     def __init__(self, model_path):
         super().__init__(model_path)
         self.processor = spm.SentencePieceProcessor(str(model_path))
+        self.bos_token_id = self.bos_id()
+        self.eos_token_id = self.eos_id()

     def encode(self, text):
         return self.processor.EncodeAsIds(text)

@@ -86,6 +88,8 @@ def __init__(self, model_path):
         # BOS / EOS token IDs
         self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
         self._eos_id: int = self.special_tokens["<|end_of_text|>"]
+        self.bos_token_id = self.bos_id()
+        self.eos_token_id = self.eos_id()

     def encode(self, text):
         return self.model.encode(text)
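Both wrappers now mirror their bos_id()/eos_id() methods as plain bos_token_id/eos_token_id attributes. A minimal sketch of the pattern, assuming (this is not shown in the diff) that downstream eval code reads the attribute-style names rather than calling the methods:

    class TokenizerSketch:
        # stand-in tokenizer exposing method-style ids, like the wrappers above
        def bos_id(self) -> int:
            return 1

        def eos_id(self) -> int:
            return 2

        def __init__(self):
            # mirror the method results as attributes, as the diff does
            self.bos_token_id = self.bos_id()
            self.eos_token_id = self.eos_id()

    tok = TokenizerSketch()
    assert (tok.bos_token_id, tok.eos_token_id) == (1, 2)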
torchao/quantization/GPTQ.py (11 changes: 7 additions & 4 deletions)

@@ -81,6 +81,9 @@ def __init__(

         # trace model for one input
         one_input = [multi.values[0].cpu() for multi in inputs]  # pyre-ignore[16]
+        # needed for GPTQ on the torchao llama model
+        import torchao
+        torchao._models.llama.model.use_index_put_for_kv_cache = True
         exported_model = torch._dynamo.export(
             model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)

@@ -95,7 +98,7 @@ def __init__(
         self.groupsize = groupsize
         self.inputs = inputs
         self.gptq_done = False
-        self.debug = False
+        self.debug = True

     def configure_quantization_mode(
         self,

@@ -672,9 +675,9 @@ def quantize(
 class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer):
     def __init__(
         self,
-        blocksize,
-        percdamp,
-        groupsize,
+        blocksize=128,
+        percdamp=0.01,
+        groupsize=64,
         inner_k_tiles=8,
         padding_allowed=True,
         device: torch.device = torch.device("cuda"),
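With the defaults added above, Int4WeightOnlyGPTQQuantizer no longer has to be constructed with blocksize, percdamp, and groupsize spelled out. A hedged usage sketch based only on the signature shown in this diff (running it assumes torchao is installed and a CUDA device is available):

    import torch
    from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer

    # defaults from the diff: blocksize=128, percdamp=0.01, groupsize=64, device=cuda
    quantizer = Int4WeightOnlyGPTQQuantizer()

    # eval.py above overrides the group size and target device explicitly
    quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=128, device=torch.device("cuda"))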
torchao/quantization/__init__.py (2 changes: 0 additions & 2 deletions)

@@ -17,8 +17,6 @@
     "swap_conv2d_1x1_to_linear"
     "safe_int_mm",
     "autoquant",
-    "change_linears_to_autoquantizable",
-    "change_autoquantizable_to_quantized",
     "get_scale",
     "SmoothFakeDynQuantMixin",
     "SmoothFakeDynamicallyQuantizedLinear",
torchao/quantization/autoquant.py (16 changes: 11 additions & 5 deletions)

@@ -17,6 +17,12 @@
 except:
     from torch._inductor.runtime.runtime_utils import do_bench

+__all__ = [
+    "AutoQuantizableLinearWeight",
+    "autoquant",
+]
+
+
 aten = torch.ops.aten

 AUTOQUANT_CACHE = {}

@@ -382,11 +388,11 @@ def from_float(cls, weight):
     AQInt8DynamicallyQuantizedLinearWeight,
 ]

-def change_linears_to_autoquantizable(model, **kwargs):
+def _change_linears_to_autoquantizable(model, **kwargs):
     """
     Converts all linear weight tensors to the
     AutoQuantizableLinearWeight tensor subclass. Expectation is that this is followed
-    by running the model and then calling change_autoquantizable_to_quantized
+    by running the model and then calling _change_autoquantizable_to_quantized
     """
     from torchao.quantization.quant_api import _is_linear
     filter_fn = kwargs.pop("filter_fn", _is_linear)

@@ -401,7 +407,7 @@ def change_linears_to_autoquantizable(model, **kwargs):
         filter_fn if filter_fn is not None else _is_linear,
     )

-def change_autoquantizable_to_quantized(model, **kwargs):
+def _change_autoquantizable_to_quantized(model, **kwargs):
     """
     Converts AutoQuantizableLinearWeight tensor subclasses
     to various quantized/non-quantized tensor subclasses depending

@@ -490,7 +496,7 @@ def autoquant(

     # perform initial swap from linear weights
     # to AutoQuantizableLinearWeight
-    change_linears_to_autoquantizable(
+    _change_linears_to_autoquantizable(
         model,
         filter_fn=filter_fn,
         qtensor_class_list=qtensor_class_list,

@@ -531,7 +537,7 @@ def autoquant_prehook(module, args, kwargs):
     # note the torch.compile wrapper (eval_frame) moves the assignment of any assigned
     # attributes to the inner model that didn't exist before, so we have to call delattr on the inner model
     def finalize_autoquant():
-        change_autoquantizable_to_quantized(
+        _change_autoquantizable_to_quantized(
             real_model,
             **aq_kwargs,
         )
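Since change_linears_to_autoquantizable and change_autoquantizable_to_quantized are now private and gone from torchao/quantization/__init__.py above, callers are expected to go through autoquant. A hedged sketch of that flow, mirroring the manual=True call generate.py makes in this PR; treat it as an illustration of the intended sequence rather than a tested recipe, since kernel and hardware requirements are outside this diff:

    import torch
    from torchao.quantization.autoquant import autoquant

    model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to("cuda", torch.bfloat16)

    # swaps Linear weights to AutoQuantizableLinearWeight; the now-private
    # _change_linears_to_autoquantizable helper is called internally
    model = autoquant(model, manual=True)

    # running the model records example shapes; a later finalize step (the
    # finalize_autoquant shown in the diff above) picks concrete quantized subclasses
    model(torch.randn(8, 64, device="cuda", dtype=torch.bfloat16))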