Add support for AQTLayout, PlainAQTLayout and `TensorCoreTiledAQTLayout` (#278)

Summary:
Today `AffineQuantizedTensor` has a hardcoded storage format of `int_data`, `scale` and `zero_point` tensors, which does not work if we want to support
packed weights. In this PR we add support for hiding the storage details of `AffineQuantizedTensor` behind a family of tensor subclasses, all of which
inherit from the base storage type `AQTStorage` (affine quantized tensor storage).

This PR only adds support for a plain storage tensor (`PlainAQTStorage`) that stores the `int_data`, `scale` and `zero_point` tensors directly;
a follow-up PR will add support for storing packed weights (the result of `torch.ops.aten._convert_weight_to_int4pack`) in a different
type of `AQTStorage`.
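
As a rough illustration of the idea (this is a hedged sketch, not the code in this PR), a plain storage subclass could simply wrap the three tensors. The real implementation also needs the full tensor-subclass machinery (`__tensor_flatten__`/`__tensor_unflatten__`, `__torch_dispatch__`), which is omitted here:

```python
import torch

class PlainAQTStorage(torch.Tensor):
    """Sketch only: plain storage that keeps int_data, scale and zero_point as-is."""

    @staticmethod
    def __new__(cls, int_data, scale, zero_point):
        # The wrapper subclass reports the shape/dtype/device of the quantized data.
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=int_data.dtype, device=int_data.device
        )

    def __init__(self, int_data, scale, zero_point):
        self.int_data = int_data
        self.scale = scale
        self.zero_point = zero_point
```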

`AffineQuantizedTensor` will have the following (see the dispatch sketch after this list):
- `storage_tensor`: `AQTStorage` (can hold data in different storage formats)
- `storage_layout`: `str` (a string identifying the type of `storage_tensor` we have; can be used for dispatch)
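
For example, op dispatch can branch on `storage_layout`. The snippet below is an illustrative sketch of that pattern; the function name and layout strings are assumptions, not necessarily the names used in this PR:

```python
import torch.nn.functional as F

def _quantized_linear(input, weight, bias=None):
    # weight is an AffineQuantizedTensor; pick a kernel based on its storage layout.
    if weight.storage_layout == "plain":
        # Plain storage: dequantize int_data/scale/zero_point and use the fp path.
        return F.linear(input, weight.dequantize(), bias)
    elif weight.storage_layout == "tensor_core_tiled":
        # Packed int4 storage: would call the fused int4 tinygemm kernel here.
        raise NotImplementedError("packed path added in a follow-up PR")
    raise NotImplementedError(f"unsupported storage_layout: {weight.storage_layout}")
```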

Test Plan:
python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored May 29, 2024
1 parent 62745fc commit 374fec4
Showing 3 changed files with 259 additions and 53 deletions.
8 changes: 4 additions & 4 deletions test/quantization/test_quant_api.py
@@ -110,8 +110,8 @@ def __init__(self, m=64, n=32, k=64):
self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

-def example_inputs(self, batch_size=1):
-    return (torch.randn(batch_size, self.linear1.in_features).to(torch.float),)
+def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
+    return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

def forward(self, x):
x = self.linear1(x)
@@ -450,7 +450,7 @@ def test_quantized_tensor_subclass_int4(self):
# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
-example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")

groupsize = 32
m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize))
@@ -496,7 +496,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
-example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))
+example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
m = quantize(m, get_apply_int8dyn_quant())

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)