Skip to content

Commit

Permalink
Allow cpu and gpu in int4wo and int4wo-gptq quantizer (pytorch#131)
Browse files Browse the repository at this point in the history
Summary:
att

Test Plan:
verified in torchat

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 authored Apr 12, 2024
1 parent 37ae1d2 commit 13d3ac9
Showing 1 changed file with 26 additions and 31 deletions.
57 changes: 26 additions & 31 deletions torchao/quantization/GPTQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,12 @@ def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])
c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
c = torch.ops.aten._weight_int4pack_mm(
x.to(torch.bfloat16),
weight_int4pack,
groupsize,
scales_and_zeros.to(torch.bfloat16)
).to(dtype=x.dtype)
new_shape = origin_x_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c
Expand All @@ -776,8 +781,8 @@ class WeightOnlyInt4Linear(torch.nn.Module):
weight: torch.Tensor

def __init__(
self, in_features: int, out_features: int,
bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
self, in_features: int, out_features: int,
bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8,
) -> None:
super().__init__()
self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
Expand All @@ -794,23 +799,16 @@ def __init__(

assert out_features % 8 == 0, "require out_features % 8 == 0"
assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
if use_cuda:
self.register_buffer(
"weight",
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
)
else:
self.register_buffer(
"weight",
torch.empty((out_features, in_features // 2), dtype=torch.uint8)
)
self.register_buffer(
"weight",
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
)
self.register_buffer(
"scales_and_zeros",
torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(torch.bfloat16)
if self.padding:
import torch.nn.functional as F
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
Expand All @@ -819,24 +817,25 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
)

def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=True, skip_layer_func = None):
def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None):

for name, child in module.named_children():
if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)):
if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
setattr(module, name, WeightOnlyInt4Linear(
child.in_features, child.out_features, bias=False,
groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
groupsize=groupsize, inner_k_tiles=inner_k_tiles,
))
else:
replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda, skip_layer_func)
replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func)

class Int4WeightOnlyQuantizer(Quantizer):
def __init__(
self,
groupsize: int = 256,
padding_allowed: bool = True,
inner_k_tiles: Optional[int] = 8,
device: torch.device = torch.device("cuda"),
) -> None:
super().__init__()
assert inner_k_tiles in [2, 4, 8]
Expand All @@ -845,6 +844,7 @@ def __init__(
self.inner_k_tiles = inner_k_tiles
self.groupsize: int = groupsize
self.padding_allowed: bool = padding_allowed
self.device: torch.device = device

@torch.no_grad()
def _create_quantized_state_dict(
Expand Down Expand Up @@ -885,9 +885,9 @@ def _create_quantized_state_dict(
4, # n_bit
self.groupsize,
)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to("cuda"), self.inner_k_tiles)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cuda")
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cuda")
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(self.device)
return cur_state_dict

def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
Expand Down Expand Up @@ -916,12 +916,14 @@ def __init__(
groupsize,
inner_k_tiles=8,
padding_allowed=True,
device: torch.device = torch.device("cuda"),
):
self.blocksize = blocksize
self.percdamp = percdamp
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
self.padding_allowed = padding_allowed
self.device = device
self.act_fake_quant_func = None
n_bit = 4
self.get_qparams_func = lambda w: get_groupwise_affine_qparams(
Expand Down Expand Up @@ -956,10 +958,10 @@ def make_names_and_values_dict_func(q, qparams):
new_k = k
# how much we need to pad the weight
delta_k = new_k - q.shape[1]
q = q.to(torch.int32)
q = q.to(torch.int32).to(self.device)
final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
scales = qparams[0].to(torch.bfloat16)
zeros = qparams[1].to(torch.bfloat16)
scales = qparams[0].to(torch.bfloat16).to(self.device)
zeros = qparams[1].to(torch.bfloat16).to(self.device)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
# how many new groups we need for padded weight
delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
Expand All @@ -970,13 +972,12 @@ def make_names_and_values_dict_func(q, qparams):
super().__init__()

def _convert_for_runtime(self, model):
# TODO: temporary path for gpt-fast, will remove later
replace_linear_int4(
model,
self.groupsize,
self.inner_k_tiles,
self.padding_allowed,
skip_layer_func = self.skip_layer_func,
skip_layer_func=self.skip_layer_func,
)
return model

Expand Down Expand Up @@ -1160,7 +1161,6 @@ def __init__(
scales_precision: torch.dtype = torch.float32,
inner_k_tiles: Optional[int] = None,
_is_gpt_fast: bool = False,
_use_cuda: bool = True,
) -> None:
super().__init__()
if _is_gpt_fast:
Expand All @@ -1169,7 +1169,6 @@ def __init__(
else:
assert inner_k_tiles is None
self._is_gpt_fast = _is_gpt_fast
self._use_cuda = _use_cuda
self.inner_k_tiles = inner_k_tiles
self.groupsize: int = groupsize
self.padding_allowed: bool = padding_allowed
Expand Down Expand Up @@ -1238,7 +1237,6 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
self.groupsize,
self.inner_k_tiles,
self.padding_allowed,
self._use_cuda,
)
else:
replace_linear_8da4w(
Expand Down Expand Up @@ -1270,10 +1268,8 @@ def __init__(
padding_allowed=True,
precision=torch.float32,
_is_gpt_fast=False,
_use_cuda=True,
):
self._is_gpt_fast = _is_gpt_fast
self._use_cuda = _use_cuda
self.blocksize = blocksize
self.percdamp = percdamp
self.groupsize = groupsize
Expand Down Expand Up @@ -1352,7 +1348,6 @@ def _convert_for_runtime(self, model):
self.groupsize,
self.inner_k_tiles,
self.padding_allowed,
self._use_cuda,
)
else:
replace_linear_8da4w(
Expand Down

0 comments on commit 13d3ac9

Please sign in to comment.