Use quantized tensor in v2 quantizers
Signed-off-by: Daemyung Jang <quic_daemyung@quicinc.com>
Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com>
Co-authored-by: Daemyung Jang <quic_daemyung@quicinc.com>
quic-mtuttle and quic-daemyung authored Mar 2, 2024
1 parent 685b641 commit 9fd56c9
Showing 2 changed files with 38 additions and 23 deletions.
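A minimal usage sketch of the calling convention this commit changes, for orientation before the diff. Module paths, constructor arguments, and method names are taken from the hunks and tests below; the input shape and variable names are assumptions.

import torch
from aimet_torch.experimental.v2.quantization.quantizers.affine import Quantize, Dequantize

x = torch.randn(16, 4)
quantize = Quantize((1,), 8, False)        # (shape, bitwidth, symmetric), as in the new test below

with quantize.compute_encodings():         # calibrate scale/offset from sample data
    _ = quantize(x)

# Before this commit: x_int, scale, offset = quantize(x)
# After this commit: a single QuantizedTensor that carries its AffineEncoding
x_q = quantize(x)
x_q.quantized_repr()                       # integer-valued representation
x_q.encoding.scale, x_q.encoding.offset    # scale/offset now travel with the tensor
x_hat = Dequantize()(x_q)                  # equivalent to x_q.dequantize()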
@@ -38,7 +38,7 @@
 """ Affine quantizers """
 
 import abc
-from typing import Optional, Tuple, List, Dict
+from typing import Optional, List, Dict
 import contextlib
 import functools

@@ -47,6 +47,8 @@
 
 from aimet_torch.experimental.v2.utils import patch_attr, _is_expandable, StatisticsNotFoundError
 from aimet_torch.experimental.v2.quantization.encoding_analyzer import EncodingAnalyzer, MinMaxEncodingAnalyzer
+from aimet_torch.experimental.v2.quantization.encodings import AffineEncoding
+from aimet_torch.experimental.v2.quantization.quantized_tensor import QuantizedTensor
 from aimet_torch.experimental.v2.quantization.quantizers.base import QuantizerBase
 from aimet_torch.experimental.v2.quantization.backends import get_backend
 from aimet_torch.experimental.v2.utils import ste_round
@@ -311,7 +313,7 @@ class Quantize(MinMaxQuantizer):
     """
     Applies quantization to the input
     """
-    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(self, input: torch.Tensor) -> QuantizedTensor:
         """
         :param input: Input to quantize
         :return: Quantized output and scale/offset associated with it
@@ -324,8 +326,9 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
         scale = self.get_scale()
         offset = self.get_offset()
-        input_q = get_backend().quantize(input, scale, offset, self.bitwidth)
-        return input_q, scale, offset
+        return QuantizedTensor(get_backend().quantize(input, scale, offset, self.bitwidth),
+                               AffineEncoding(scale, offset, self.bitwidth),
+                               lambda x: get_backend().dequantize(torch.Tensor(x), scale, offset))
 
 
 class QuantizeDequantize(MinMaxQuantizer):
@@ -352,15 +355,10 @@ class Dequantize(torch.nn.Module):
     """
     Applies dequantization to the input
     """
-    def forward(self,
-                input: torch.Tensor,
-                scale: torch.Tensor,
-                offset: torch.Tensor) -> torch.Tensor:
+    def forward(self, input: QuantizedTensor) -> torch.Tensor:
         # pylint: disable=no-self-use
         """
         :param input: Input to dequantize
-        :param scale: Quantization scale
-        :param offset: Quantization offset
         :return: Dequantized output
         """
-        return get_backend().dequantize(input, scale, offset)
+        return input.dequantize()
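A practical consequence of the new Dequantize signature above: scale and offset ride along inside the QuantizedTensor, so a quantize/dequantize pair composes like ordinary modules. A hedged sketch under the same assumptions as the one above (the nn.Sequential composition is illustrative, not taken from the AIMET sources):

import torch
from torch import nn
from aimet_torch.experimental.v2.quantization.quantizers.affine import Quantize, Dequantize

# Quantize emits a QuantizedTensor; Dequantize now needs only that single argument.
pair = nn.Sequential(Quantize((1,), 8, False), Dequantize())

x = torch.randn(16, 4)
with pair[0].compute_encodings():          # calibrate the quantizer on sample data
    _ = pair[0](x)

x_hat = pair(x)                            # QuantizedTensor -> Dequantize -> torch.Tensor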
@@ -42,7 +42,8 @@
 from torch import nn
 from torch.optim import SGD, RMSprop, Adagrad, Adam, AdamW
 from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
-from aimet_torch.experimental.v2.quantization.quantizers.affine import AffineQuantizerBase, Quantize, QuantizeDequantize
+from aimet_torch.experimental.v2.quantization.quantizers.affine import AffineQuantizerBase, Quantize, \
+    QuantizeDequantize, Dequantize
 from aimet_torch.experimental.v2.quantization.backends import get_backend


@@ -148,11 +149,11 @@ def test_quantize_compute_encodings(quantize: Quantize, x: torch.Tensor):
                                             quantize.bitwidth)
 
     with quantize.compute_encodings():
-        x_int, scale_x, offset_x = quantize(x)
+        x_int = quantize(x)
 
-    assert torch.allclose(x_int, expected_x_int)
-    assert torch.allclose(scale_x, dynamic_scale)
-    assert torch.allclose(offset_x, dynamic_offset)
+    assert torch.allclose(x_int.quantized_repr(), expected_x_int.to(x_int.encoding.dtype))
+    assert torch.allclose(x_int.encoding.scale, dynamic_scale)
+    assert torch.allclose(x_int.encoding.offset, dynamic_offset)
     assert torch.allclose(quantize.min, dynamic_min)
     assert torch.allclose(quantize.max, dynamic_max)
     assert torch.allclose(quantize.get_scale(), dynamic_scale)
@@ -262,7 +263,7 @@ def test_backward_during_compute_encodings(q: AffineQuantizerBase, x: torch.Tensor):
 
     with q.compute_encodings():
         if isinstance(q, Quantize):
-            output, scale, offset = q(x)
+            output = q(x)
         else:
             output = q(x)
         output.backward(torch.zeros_like(output))
@@ -326,12 +327,12 @@ def test_quantize_forward(quantize: Quantize, x: torch.Tensor):
     When: forward() invoked
     Then: forward() returns parametric quantization output.
     """
-    output, scale, offset = quantize(x)
+    output = quantize(x)
     expected_output = get_backend().quantize(x,
                                              quantize.get_scale(),
                                              quantize.get_offset(),
                                              quantize.bitwidth)
-    assert torch.allclose(output, expected_output)
+    assert torch.allclose(output.quantized_repr(), expected_output.to(output.encoding.dtype))
 
 
 @pytest.mark.parametrize('quantize_dequantize', [
@@ -377,7 +378,7 @@ def test_backward(q: AffineQuantizerBase, x: torch.Tensor):
     Then: self.min.grad and self.max.grad should be computed
     """
     if isinstance(q, Quantize):
-        output, scale, offset = q(x)
+        output = q(x)
     else:
         output = q(x)
     output.backward(torch.zeros_like(output))
@@ -407,7 +408,7 @@ def test_backward_with_no_grad(q, x: torch.Tensor):
     x = x.clone().requires_grad_(True)
     with torch.no_grad():
         if isinstance(q, Quantize):
-            output, scale, offset = q(x)
+            output = q(x)
         else:
             output = q(x)
     output = output + x
@@ -557,7 +558,7 @@ def test_symmetric_learning(q, x, optim_cls):
 
     for _ in range(10):
         if isinstance(q, Quantize):
-            output, scale, offset = q(x)
+            output = q(x)
         else:
             output = q(x)
         output.backward(torch.randn_like(output))
@@ -611,7 +612,7 @@ def test_asymmetric_learning(q, x, optim_cls):
 
     for _ in range(10):
         if isinstance(q, Quantize):
-            output, scale, offset = q(x)
+            output = q(x)
         else:
             output = q(x)
         output.backward(torch.randn_like(output))
@@ -772,3 +773,19 @@ def test_is_initialized():
     res = pickle.dumps(qdq)
     qdq = pickle.loads(res)
     assert qdq.is_initialized()
+
+
+@torch.no_grad()
+@pytest.mark.parametrize('symmetric', [True, False])
+def test_quantize_dequantize_then_quantize_and_dequantize_equality(x, symmetric):
+    qdq = QuantizeDequantize((1,), 8, symmetric)
+    q = Quantize((1,), 8, symmetric)
+    dq = Dequantize()
+
+    with qdq.compute_encodings(), q.compute_encodings():
+        _ = qdq(x)
+        _ = q(x)
+
+    a = qdq(x)
+    b = dq(q(x))
+    assert torch.allclose(a, b)
