Make developer experience better for extending AQT (#749)
Make developer experience better
drisspg authored Aug 26, 2024
1 parent c2f4460 commit 37276d6
Showing 2 changed files with 11 additions and 5 deletions.
torchao/dtypes/affine_quantized_tensor.py (10 additions, 5 deletions)
@@ -70,6 +70,12 @@ def __repr__(self):
 # Tensor Subclass Definition #
 ##############################
 
+
+class QuantizedLinearNotImplementedError(NotImplementedError):
+    """ Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table """
+    pass
+
+
 _QLINEAR_DISPATCH_TABLE = {}
 def _register_quantized_linear_dispatch(dispatch_condition, impl):
     _QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl
@@ -158,8 +164,7 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
         for dispatch_condition, impl in _QLINEAR_DISPATCH_TABLE.items():
             if dispatch_condition(input_tensor, weight_tensor, bias):
                 return impl(input_tensor, weight_tensor, bias)
-
-        raise NotImplementedError("No specialized dispatch found for quantized linear op")
+        raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op")
 
     def __tensor_flatten__(self):
         return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
@@ -887,7 +892,7 @@ def _(func, types, args, kwargs):
     # make the branches easier to understand in `_quantized_linear_op`
     try:
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except:
+    except QuantizedLinearNotImplementedError:
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
         if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -910,7 +915,7 @@ def _(func, types, args, kwargs):
     try:
         weight_tensor = weight_tensor.t()
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except:
+    except QuantizedLinearNotImplementedError:
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
         if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -930,7 +935,7 @@ def _(func, types, args, kwargs):
     try:
         weight_tensor = weight_tensor.t()
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except:
+    except QuantizedLinearNotImplementedError:
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
         if isinstance(weight_tensor, AffineQuantizedTensor):
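The effect of this first file's changes: `_quantized_linear_op` walks `_QLINEAR_DISPATCH_TABLE` and, when no registered condition matches, raises the new `QuantizedLinearNotImplementedError`; the three `def _(...)` call sites patched above now catch exactly that type before dequantizing and falling back, so a genuine bug inside a dispatch impl is no longer swallowed by the old bare `except:`. Below is a minimal sketch of how an extension could hook into the table; the `_my_*` names are hypothetical, while `_register_quantized_linear_dispatch` and `dequantize()` come from this file:

import torch
import torch.nn.functional as F

from torchao.dtypes.affine_quantized_tensor import _register_quantized_linear_dispatch

def _my_dispatch_condition(input_tensor, weight_tensor, bias):
    # Claim the op only for inputs this kernel supports; returning False
    # lets _quantized_linear_op keep scanning the table.
    return input_tensor.dim() == 2 and input_tensor.dtype == torch.bfloat16

def _my_impl(input_tensor, weight_tensor, bias):
    # Stand-in for a real fused kernel: dequantize and defer to F.linear.
    return F.linear(input_tensor, weight_tensor.dequantize(), bias)

_register_quantized_linear_dispatch(_my_dispatch_condition, _my_impl)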
torchao/quantization/autoquant.py (1 addition, 0 deletions)
@@ -413,6 +413,7 @@ class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMi
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
     uses a different kernel
     """
+    @staticmethod
     def _quantized_linear_op(act_mat, w_qtensor, bias):
         orig_shape = act_mat.shape
         y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale)
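Why the `@staticmethod` matters here: the dispatch path above calls `weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)` on an instance while also passing the weight explicitly, so without the decorator Python would prepend `self` and hand the three-parameter function four arguments. A standalone toy example (not torchao code) of the failure mode:

class Broken:
    def op(x, w, bias):              # no @staticmethod: x silently binds to the instance
        return (x, w, bias)

class Fixed:
    @staticmethod
    def op(x, w, bias):
        return (x, w, bias)

Fixed().op("act", "weight", None)    # OK
Broken().op("act", "weight", None)   # TypeError: op() takes 3 positional arguments but 4 were given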
