diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index bd82f36092..91e604c8cf 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -764,7 +764,12 @@ def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None): def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) + c = torch.ops.aten._weight_int4pack_mm( + x.to(torch.bfloat16), + weight_int4pack, + groupsize, + scales_and_zeros.to(torch.bfloat16) + ).to(dtype=x.dtype) new_shape = origin_x_size[:-1] + (out_features,) c = c.reshape(new_shape) return c @@ -776,8 +781,8 @@ class WeightOnlyInt4Linear(torch.nn.Module): weight: torch.Tensor def __init__( - self, in_features: int, out_features: int, - bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True, + self, in_features: int, out_features: int, + bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, ) -> None: super().__init__() self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) @@ -794,23 +799,16 @@ def __init__( assert out_features % 8 == 0, "require out_features % 8 == 0" assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" - if use_cuda: - self.register_buffer( - "weight", - torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) - ) - else: - self.register_buffer( - "weight", - torch.empty((out_features, in_features // 2), dtype=torch.uint8) - ) + self.register_buffer( + "weight", + torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) + ) self.register_buffer( "scales_and_zeros", torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) ) def forward(self, input: torch.Tensor) -> torch.Tensor: - input = input.to(torch.bfloat16) if self.padding: import torch.nn.functional as F input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) @@ -819,17 +817,17 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.weight, self.scales_and_zeros, self.out_features, self.groupsize ) -def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=True, skip_layer_func = None): +def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None): for name, child in module.named_children(): if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)): if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: setattr(module, name, WeightOnlyInt4Linear( child.in_features, child.out_features, bias=False, - groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda + groupsize=groupsize, inner_k_tiles=inner_k_tiles, )) else: - replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda, skip_layer_func) + replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func) class Int4WeightOnlyQuantizer(Quantizer): def __init__( @@ -837,6 +835,7 @@ def __init__( groupsize: int = 256, padding_allowed: bool = True, inner_k_tiles: Optional[int] = 8, + device: torch.device = torch.device("cuda"), ) -> None: super().__init__() assert inner_k_tiles in [2, 4, 8] @@ -845,6 +844,7 @@ def __init__( self.inner_k_tiles = inner_k_tiles self.groupsize: int = groupsize self.padding_allowed: bool = padding_allowed + self.device: torch.device = device @torch.no_grad() def _create_quantized_state_dict( @@ -885,9 +885,9 @@ def _create_quantized_state_dict( 4, # n_bit self.groupsize, ) - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to("cuda"), self.inner_k_tiles) - cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cuda") - cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cuda") + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device) + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(self.device) return cur_state_dict def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: @@ -916,12 +916,14 @@ def __init__( groupsize, inner_k_tiles=8, padding_allowed=True, + device: torch.device = torch.device("cuda"), ): self.blocksize = blocksize self.percdamp = percdamp self.groupsize = groupsize self.inner_k_tiles = inner_k_tiles self.padding_allowed = padding_allowed + self.device = device self.act_fake_quant_func = None n_bit = 4 self.get_qparams_func = lambda w: get_groupwise_affine_qparams( @@ -956,10 +958,10 @@ def make_names_and_values_dict_func(q, qparams): new_k = k # how much we need to pad the weight delta_k = new_k - q.shape[1] - q = q.to(torch.int32) + q = q.to(torch.int32).to(self.device) final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) - scales = qparams[0].to(torch.bfloat16) - zeros = qparams[1].to(torch.bfloat16) + scales = qparams[0].to(torch.bfloat16).to(self.device) + zeros = qparams[1].to(torch.bfloat16).to(self.device) scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) # how many new groups we need for padded weight delta_groups = new_k // groupsize - scales_and_zeros.shape[0] @@ -970,13 +972,12 @@ def make_names_and_values_dict_func(q, qparams): super().__init__() def _convert_for_runtime(self, model): - # TODO: temporary path for gpt-fast, will remove later replace_linear_int4( model, self.groupsize, self.inner_k_tiles, self.padding_allowed, - skip_layer_func = self.skip_layer_func, + skip_layer_func=self.skip_layer_func, ) return model @@ -1160,7 +1161,6 @@ def __init__( scales_precision: torch.dtype = torch.float32, inner_k_tiles: Optional[int] = None, _is_gpt_fast: bool = False, - _use_cuda: bool = True, ) -> None: super().__init__() if _is_gpt_fast: @@ -1169,7 +1169,6 @@ def __init__( else: assert inner_k_tiles is None self._is_gpt_fast = _is_gpt_fast - self._use_cuda = _use_cuda self.inner_k_tiles = inner_k_tiles self.groupsize: int = groupsize self.padding_allowed: bool = padding_allowed @@ -1238,7 +1237,6 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: self.groupsize, self.inner_k_tiles, self.padding_allowed, - self._use_cuda, ) else: replace_linear_8da4w( @@ -1270,10 +1268,8 @@ def __init__( padding_allowed=True, precision=torch.float32, _is_gpt_fast=False, - _use_cuda=True, ): self._is_gpt_fast = _is_gpt_fast - self._use_cuda = _use_cuda self.blocksize = blocksize self.percdamp = percdamp self.groupsize = groupsize @@ -1352,7 +1348,6 @@ def _convert_for_runtime(self, model): self.groupsize, self.inner_k_tiles, self.padding_allowed, - self._use_cuda, ) else: replace_linear_8da4w(