import copy
import functools
import unittest

import torch
from torch.testing._internal import common_utils

import torchao
from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType

"""
How to use:
import unittest
from torchao.testing.utils import TorchAOBasicTestCase, copy_tests
from torch.testing._internal import common_utils
# TODO: currently there is no way to set COMMON_DEVICES/COMMON_DTYPES
# we can figure out this a bit later
# change arguments
class MyTestCase(TorchAOBasicTestCase):
TENSOR_SUBCLASS = MyDTypeTensor
FACTOR_FN = to_my_dtype
kwargs = {"target_dtype": torch.uint8}
LINEAR_MIN_SQNR = 30
# copy the instantiated tests
copy_tests(TorchAOBasicTestCase, MyTestCase, "my_test_case")
if __name__ == "__main__":
unittest.main()
"""
# copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389
def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None):  # noqa: B902
    for name, value in my_cls.__dict__.items():
        if name.startswith("test_"):
            # You cannot copy functions in Python, so we use closures here to
            # create objects with different ids. Otherwise, unittest.skip
            # would modify all methods sharing the same object id. Also, by
            # using a default argument, we create a copy instead of a
            # reference. Otherwise, we would lose access to the value.
            @functools.wraps(value)
            def new_test(self, value=value):
                return value(self)

            # Copy __dict__ which may contain test metadata
            new_test.__dict__ = copy.deepcopy(value.__dict__)

            if xfail_prop is not None and hasattr(value, xfail_prop):
                new_test = unittest.expectedFailure(new_test)

            tf = test_failures and test_failures.get(name)
            if tf is not None and suffix in tf.suffixes:
                skip_func = (
                    unittest.skip("Skipped!")
                    if tf.is_skip
                    else unittest.expectedFailure
                )
                new_test = skip_func(new_test)

            setattr(other_cls, f"{name}_{suffix}", new_test)
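

# A usage sketch for the test_failures path above: entries only need
# `.suffixes` and `.is_skip` attributes, so a simple namedtuple works
# (the `TestFailure` name below is hypothetical, not part of this module):
#
#   from collections import namedtuple
#   TestFailure = namedtuple("TestFailure", ["suffixes", "is_skip"])
#   failures = {"test_linear": TestFailure(suffixes=("my_test_case",), is_skip=True)}
#   copy_tests(TorchAOBasicTestCase, MyTestCase, "my_test_case", test_failures=failures)
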
class TorchAOBasicTestCase(common_utils.TestCase):
    """Basic test case for tensor subclasses."""

    COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

    TENSOR_SUBCLASS = AffineQuantizedTensor
    FACTORY_FN = to_affine_quantized_intx
    kwargs = {
        "mapping_type": MappingType.ASYMMETRIC,
        "block_size": (1, 32),
        "target_dtype": torch.uint8,
    }
    # minimum SQNR for the linear operation when the weight is quantized to
    # low precision with the above settings
    LINEAR_MIN_SQNR = 40
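
    # For reference: compute_error used below reports SQNR in dB, roughly
    # 20 * log10(norm(ref) / norm(ref - quantized)), so LINEAR_MIN_SQNR = 40
    # means the quantization noise should sit at least ~40 dB below the signal.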

    def test_flatten_unflatten(self):
        hp_tensor = torch.randn(4, 128)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        # __tensor_flatten__ returns the names of the inner tensor attributes
        # plus a tuple of the non-tensor attributes
        tensor_data_names, tensor_attributes = lp_tensor.__tensor_flatten__()
        tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_names}
        outer_size = lp_tensor.size()
        outer_stride = lp_tensor.stride()
        reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__(
            tensor_data_dict, tensor_attributes, outer_size, outer_stride
        )
        self.assertEqual(lp_tensor.dequantize(), reconstructed.dequantize())
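
    # A minimal sketch of the contract exercised above, assuming a toy
    # subclass with inner tensors `int_data` and `scale` (not the actual
    # AffineQuantizedTensor implementation):
    #
    #   def __tensor_flatten__(self):
    #       return ["int_data", "scale"], (self.block_size,)
    #
    #   @classmethod
    #   def __tensor_unflatten__(cls, tensor_data_dict, attributes, outer_size, outer_stride):
    #       (block_size,) = attributes
    #       return cls(tensor_data_dict["int_data"], tensor_data_dict["scale"], block_size)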

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_hp_tensor_device_dtype(self, device, dtype):
        # smoke test: construction should succeed for every device/dtype combination
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

    @common_utils.parametrize("device1", COMMON_DEVICES)
    @common_utils.parametrize("device2", COMMON_DEVICES)
    def test_device1_to_device2(self, device1, device2):
        """Note: this should be parametrized with device1 and device2
        e.g. device1 = ["cpu", "cuda"], device2 = ["cpu", "cuda"]
        """
        # to(device=...) with a keyword argument
        hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        lp_tensor.to(device=device2)

        # to(device) with a positional argument
        hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        lp_tensor.to(device2)

        # .cuda() shorthand (only valid when CUDA is available)
        if torch.cuda.is_available():
            hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
            lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
            lp_tensor.cuda()

        # .cpu() shorthand
        hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        lp_tensor.cpu()

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_transpose(self, device, dtype):
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        lp_tensor = lp_tensor.t()
        self.assertEqual(lp_tensor.shape, (128, 4))

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_linear(self, device, dtype):
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

        hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
        hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
        lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
        self.assertGreater(
            torchao.quantization.utils.compute_error(hp_res, lp_res),
            self.LINEAR_MIN_SQNR,
        )

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_linear_compile(self, device, dtype):
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

        hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
        hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
        linear = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype)
        linear.weight = torch.nn.Parameter(lp_tensor)
        lp_res = torch.compile(linear)(hp_act_tensor)
        self.assertGreater(
            torchao.quantization.utils.compute_error(hp_res, lp_res),
            self.LINEAR_MIN_SQNR,
        )
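
    # Note: routing lp_tensor through an nn.Parameter and torch.compile above
    # checks that the subclass's dispatch overrides compose with dynamo/inductor
    # tracing, not just with eager execution.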


common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)

if __name__ == "__main__":
    unittest.main()