Skip to content

Commit

Permalink
Merge branch 'main' into tinygemm
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryzh168 authored May 8, 2024
2 parents 7e53a5e + f6d56ca commit 65931e3
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 33 deletions.
9 changes: 3 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,9 @@ torch._inductor.config.use_mixed_mm = True
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

# perform autoquantization
torchao.autoquant(model, (input))

# compile the model to recover performance
model = torch.compile(model, mode='max-autotune')
model(input)
# perform autoquantization and compilation
q_model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
q_model(input)
```

### Sparsity
Expand Down
63 changes: 55 additions & 8 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
torch.nn.ReLU(),
).to(device).to(dtype)
out = model(example_input)
torchao.autoquant(model, example_input)
torchao.autoquant(model)
out2 = model(example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)
Expand All @@ -1400,7 +1400,9 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
(32, 32, 128, 128),
]))
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
if device != "cuda" and dtype != torch.bfloat16:
self.skipTest(f"autoquant currently does not support {device}")
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
Expand All @@ -1414,15 +1416,60 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
torch.nn.ReLU(),
).to(device).to(dtype)
example_input = torch.randn(m1, k, device=device, dtype=dtype)
example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
torchao.quantization.change_linears_to_autoquantizable(model)
out=model(example_input)
model(example_input2)
torchao.quantization.change_autoquantizable_to_quantized(model)
out2 = model(example_input)
example_input2 = torch.randn(m1, k, device=device, dtype=dtype)
out = model(example_input)

mod = torchao.autoquant(torch.compile(model))
mod.forward_log_only(example_input)
mod(example_input2)

out2 = mod(example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)

@parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
[
(1, 1, 128, 128),
(1, 32, 128, 128),
(32, 32, 128, 128),
]))
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n):
if device != "cuda" and dtype != torch.bfloat16:
self.skipTest(f"autoquant currently does not support {device}")
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
if dtype == torch.bfloat16:
self.skipTest(f"bfloat16 requires sm80+")
if m1 == 1 or m2 == 1:
self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")

class NeedsKwargs(torch.nn.Module):
def __init__(self):
super().__init__()
self.rel = torch.nn.ReLU()
self.lin = torch.nn.Linear(k,n)

def forward(self, x, y):
x = self.rel(x)
z = self.lin(x + y)
return z

model = NeedsKwargs().to(device).to(dtype)
example_input = {
"x": torch.randn(m1, k, device=device, dtype=dtype),
"y": torch.randn(m1, k, device=device, dtype=dtype),
}
out = model(**example_input)

mod = torchao.autoquant(torch.compile(model))
mod.forward_log_only(**example_input)
mod(**example_input)

out2 = mod(**example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)

class TestAOTI(unittest.TestCase):
@parameterized.expand(
Expand Down
10 changes: 5 additions & 5 deletions torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ torch._inductor.config.use_mixed_mm = True
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

# perform autoquantization
torchao.autoquant(model, (input))
# perform autoquantization and torch.compile
model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
# pass in an input which is used in order to pick fastest quantization operations
# and apply torch compilation.
model(input)
```

Expand Down Expand Up @@ -167,6 +167,6 @@ model(input)

## Notes

1. APIs have been hardware tested on A100 and T4(colab)
1. APIs have been hardware tested on A100 and T4(colab)
2. While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.
3. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) which is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.
80 changes: 67 additions & 13 deletions torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode)
update_cache(q_cls, shapes_and_dtype, res)

@torch.no_grad()
def to_quantized(self, error_on_unseen, **kwargs):
if error_on_unseen and self.logged_data == {}:
raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
Expand Down Expand Up @@ -123,7 +124,7 @@ def count_shapes(self, do_print=True):
torch._dynamo.reset()
cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen
if shape_count is not None and shape_count > 1:
print(f">total_time: {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms")
print(f">time (all shapes): {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms")
if best_time >= cur_time:
best_time = cur_time
best_cls = q_cls
Expand Down Expand Up @@ -176,6 +177,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
if func is aten.detach.default:
return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))

@torch.no_grad()
def do_autoquant_bench(op, *args, **kwargs):
"""
runs benchmark op(*args, **kwargs) avoiding torch.compile overhead
Expand Down Expand Up @@ -335,6 +337,7 @@ def change_linears_to_autoquantizable(model, **kwargs):
"""
from torchao.quantization.quant_api import _is_linear
filter_fn = kwargs.pop("filter_fn", _is_linear)
_ = kwargs.pop("error_on_unseen", True) # same kwargs used for this and to_quantized
kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST)
kwargs["mode"] = kwargs.get("mode", ["relu", None])
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
Expand Down Expand Up @@ -374,20 +377,71 @@ def change_autoquantizable_to_quantized(model, **kwargs):
torch._dynamo.reset()

@torch.no_grad()
def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **kwargs):
def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **aq_kwargs):
"""
Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape
across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer
and applies that type of quantization.
wraps model in AutoQuantWrapper, if example_input is provided, runs forward on it, otherwise returns the wrapped model.
AutoQuantWrapper handles instances where model is torch.compiled by first performing autoquantization on the original
model and then letting the torch.compile run/tracing occur.
Example usage::
torchao.autoquant(torch.compile(model))
model(*example_input)
"""
if filter_fn is None:
from torchao.quantization.quant_api import _is_linear
filter_fn = _is_linear
# the hook we will use to intercept the model forward and perform
# autoquantization
def autoquant_prehook(module, args, kwargs):
module.forward_log_only(*args, **kwargs)
change_autoquantizable_to_quantized(
module,
**aq_kwargs,
)
module.clean_up_autoquant_hooks_and_attrs()
return args, kwargs

# perform initial swap from linear weights
# to AutoQuantizableLinearWeight
change_linears_to_autoquantizable(
model,
filter_fn=filter_fn,
qtensor_class_list=qtensor_class_list,
mode=mode,
**aq_kwargs
)

# access actual model of torch.compile wrapper if needed
if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
real_model = model._orig_mod
else:
real_model = model

# we need a consistent way to run the model which bypasses both
# A) the torch.compile tracing (so we need to run the inner model directly)
# B) the autoquant_prehook we're about to register (so we call forward directly)
model.forward_log_only = lambda *args, **kwargs: real_model.forward(*args, **kwargs)

# the autoquant_prehook intercepts the forward call and performs autoquantization
# and then deletes the hook. if model is a torch.compile wrapper, it then
# does the tracing/compile since the prehook is naturally followed by the normal.
# model run.
handle = model.register_forward_pre_hook(autoquant_prehook, with_kwargs=True)

# note the torch.compile wrapper eval_frame moved the assignment of any assigned
# attributes to the inner model, so we have to call delattr on the inner model
def clean_up_autoquant_hooks_and_attrs():
try:
handle.remove()
delattr(real_model, "clean_up_autoquant_hooks_and_attrs")
delattr(real_model, "forward_log_only")
except:
pass
model.clean_up_autoquant_hooks_and_attrs = clean_up_autoquant_hooks_and_attrs

change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs)
if not isinstance(example_input, (tuple, list)):
assert isinstance(example_input, torch.Tensor)
# if example input was provided, check it and run it
if isinstance(example_input, torch.Tensor):
example_input = [example_input]
model(*example_input)
change_autoquantizable_to_quantized(model, **kwargs)
if isinstance(example_input, (tuple, list)):
model(*example_input)

return model
4 changes: 3 additions & 1 deletion torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Int4WeightOnlyGPTQQuantizer,
Int4WeightOnlyQuantizer,
)
from .autoquant import autoquant


__all__ = [
Expand All @@ -46,7 +47,8 @@
"Quantizer",
"TwoStepQuantizer",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer"
"Int4WeightOnlyQuantizer",
"autoquant"
]

if TORCH_VERSION_AFTER_2_3:
Expand Down

0 comments on commit 65931e3

Please sign in to comment.