GPTQ Quantized-weight Sequential Updating #177

Merged · 17 commits · Sep 24, 2024
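As background for the diff below, here is a minimal sketch of how the sequential_update option touched by this PR might be used. The GPTQModifier import path follows base.py in this diff; the oneshot entrypoint, scheme, model, and dataset values are assumptions drawn from the project's example recipes, not part of this change.

```python
from llmcompressor.modifiers.quantization.gptq.base import GPTQModifier
from llmcompressor.transformers import oneshot  # assumed entrypoint, as in the project's examples

# Quantize Linear layers one transformer block at a time so that only the
# current block's Hessians live in GPU memory (sequential_update=True).
recipe = GPTQModifier(
    targets="Linear",
    scheme="W4A16",  # illustrative scheme, following the example recipes
    block_size=128,
    sequential_update=True,
)

# Hypothetical calibration run; model and dataset names are placeholders.
oneshot(
    model="meta-llama/Llama-2-7b-hf",
    dataset="open_platypus",
    recipe=recipe,
    num_calibration_samples=512,
)
```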
45 changes: 38 additions & 7 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -1,3 +1,4 @@
import gc
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
@@ -8,12 +9,15 @@
freeze_module_quantization,
)
from loguru import logger
from pydantic import Field
from pydantic import Field, field_validator
from torch.nn import Module

from llmcompressor.core.state import State
from llmcompressor.modifiers import Modifier, ModifierFactory
from llmcompressor.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper
from llmcompressor.modifiers.quantization.gptq.utils import (
GPTQWrapper,
get_output_error,
)
from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.utils.fsdp.context import fix_fsdp_module_name
@@ -93,7 +97,7 @@ class GPTQModifier(Modifier):
and activation 8 bit quantization on the Linear layers.
"""

sequential_update: Optional[bool] = True
sequential_update: bool = True
targets: Union[str, List[str], None] = None
sequential_targets: Union[str, List[str], None] = None
block_size: int = 128
@@ -110,6 +114,17 @@
compressible_layers_: Optional[List] = None
quantization_modifier_: Any = None

@field_validator("sequential_update", mode="before")
def validate_sequential_update(cls, value: bool) -> bool:
if not value:
logger.warning(
"Not using sequential_update requires allocating all hessians in "
"GPU memory. If you are running into GPU memory issues, consider "
"using sequential_update=True"
)

return value

def on_initialize_structure(self, state: State, **kwargs):
"""
Check the model's quantization state matches that expected by this modifier,
@@ -240,9 +255,11 @@ def initialize_compression(
args = self._pruning_arguments()
comp_cls = self._compression_class()
compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args)

# if not running sequentially, allocate all hessians now
if not self.sequential_update:
# add all batch processing hooks before the forward pass
compressor.pre_compress()

self.layer_compressors_.append(compressor)

if self.sequential_update:
@@ -277,21 +294,35 @@ def apply_compression(
)
self.layer_compressors_[0].clear_early_stop()

# empty cache if not using sequential update
if not self.sequential_update:
del intermediates
gc.collect()
torch.cuda.empty_cache()

num_layers = len(self.compressible_layers_)
for idx, layer_compressor in enumerate(self.layer_compressors_):
logger.info(f"\n===== Compressing layer {idx+1}/{num_layers} " " =====")

# Prune/quantize using GPTQ
if self.sequential_update:
# in sequential mode we run the forward pass for each transformer layer
# one at a time, caching the intermediate outputs between layers
layer_compressor.pre_compress()
logger.info(f"Calibrating {layer_compressor.name}...")
intermediates = layer_compressor.calibrate_layer(intermediates)
layer_compressor.pre_compress()
unquantized_outputs = layer_compressor.calibrate_layer(intermediates)

layer_compressor.compress()
layer_compressor.post_compress()
layer_compressor.revert_layer_wrappers()

if self.sequential_update:
quantized_outputs = layer_compressor.calibrate_layer(intermediates)
error = get_output_error(unquantized_outputs, quantized_outputs)
logger.info(f"Mean output error from quantization: {error:.3f}")
intermediates = quantized_outputs
del unquantized_outputs

gc.collect()
torch.cuda.empty_cache()

self.model.config.use_cache = forward_pass_use_cache
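To make the memory trade-off behind the new validate_sequential_update warning concrete: GPTQ keeps one (in_features x in_features) float32 Hessian per quantized Linear, and non-sequential mode allocates all of them before the single calibration pass. A rough, illustrative estimate (the 4096 width is a hypothetical example, not a figure from this PR):

```python
def gptq_hessian_bytes(in_features: int, dtype_bytes: int = 4) -> int:
    """One GPTQ Hessian is (in_features x in_features), float32 by default."""
    return in_features * in_features * dtype_bytes

# A single 4096-wide Linear needs ~64 MiB for its Hessian; with every Linear
# in a large model allocated at once (sequential_update=False), this is the
# cost the warning above cautions about.
print(f"{gptq_hessian_bytes(4096) / 2**20:.0f} MiB")  # -> 64 MiB
```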
4 changes: 4 additions & 0 deletions src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py
@@ -0,0 +1,4 @@
# flake8: noqa

from .gptq_wrapper import *
from .helpers import *
51 changes: 51 additions & 0 deletions src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py
@@ -0,0 +1,51 @@
from typing import Any, Iterable, List, Tuple, Union

import torch

__all__ = ["get_output_error"]


def get_output_error(
unquantized: List[Tuple[Union[Iterable, torch.Tensor], Any]],
quantized: List[Tuple[Union[Iterable, torch.Tensor], Any]],
) -> torch.Tensor:
"""
Calculate mean l1 loss between weight-unquantized outputs and weight-quantized
outputs

:param unquantized: unquantized-weight outputs
:param quantized: quantized-weight outputs
:return: mean l1 loss between outputs
"""
unquantized_outputs = sum(
[
[output for output in outputs]
if isinstance(outputs, Iterable)
else [outputs]
for outputs, _ in unquantized
],
start=[],
)

quantized_outputs = sum(
[
[output for output in outputs]
if isinstance(outputs, Iterable)
else [outputs]
for outputs, _ in quantized
],
start=[],
)

if len(unquantized_outputs) != len(quantized_outputs):
raise ValueError(
"Number of samples of weight-unquantized and weight-quantized "
"outputs differs"
)

return sum(
[
torch.nn.functional.l1_loss(unq, q)
for unq, q in zip(unquantized_outputs, quantized_outputs)
]
) / len(unquantized_outputs)
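A toy invocation of get_output_error, assuming inputs in the (outputs, kwargs) form that calibrate_layer returns; the tensors are made up purely for illustration:

```python
import torch

from llmcompressor.modifiers.quantization.gptq.utils import get_output_error

# Two fake "layer output" records in (outputs, kwargs) form.
unquantized = [((torch.ones(2, 4),), {}), ((torch.zeros(3),), {})]
quantized = [((torch.full((2, 4), 0.9),), {}), ((torch.full((3,), 0.1),), {})]

# Mean of the per-output L1 losses: (0.1 + 0.1) / 2 = 0.1
print(get_output_error(unquantized, quantized))  # tensor(0.1000)
```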
11 changes: 6 additions & 5 deletions src/llmcompressor/modifiers/utils/layer_compressor.py
@@ -57,7 +57,7 @@ def __init__(
self.layer_index = layer_index
self.name = name
self.args = args
self.handles = None
self.handles = []
self.early_stop_handle = None
self.modules = {}

@@ -118,7 +118,6 @@ def tmp(_, inp, out):

return tmp

self.handles = []
for name in self.modules:
self.handles.append(subset[name].register_forward_hook(add_batch(name)))

Expand All @@ -129,14 +128,15 @@ def calibrate_layer(self, intermediates: Tuple[Tuple, Dict]) -> Tuple[Tuple, Dic
:param intermediates: inputs to run through the layer
:return: outputs of the layer
"""
outputs = [None for _ in range(len(intermediates))]
for idx in tqdm(range(len(intermediates))):
args, kwargs = intermediates[idx]
device = get_execution_device(self.layer)
output = self.layer(*tensors_to_device(args, device), **kwargs)
intermediates[idx] = (tensors_to_device(output, "cpu"), kwargs)
outputs[idx] = (tensors_to_device(output, "cpu"), kwargs)
torch.cuda.empty_cache()

return intermediates
return outputs

def post_compress(self):
"""
@@ -145,6 +145,8 @@ def post_compress(self):
for handle in self.handles:
handle.remove()

self.handles = []

def revert_layer_wrappers(self):
"""
Reverts wrapped root modules back to their original structure
@@ -171,7 +173,6 @@ def compress_module(module):
logger.info(f"Compressing {full_name}...")
module.compress(**self.args)
module.free()
print("done")

self.layer.apply(compress_module)
torch.cuda.empty_cache()
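The handles changes above make the hook lifecycle explicit: the list is initialized up front, and post_compress both removes the hooks and clears the list so repeated calibrate/compress cycles cannot leak hooks. A standalone illustration of that pattern (the Linear layer is a stand-in, not code from this PR):

```python
import torch

layer = torch.nn.Linear(8, 8)
handles = []

# Register a forward hook (what pre_compress does for each wrapped module).
handles.append(layer.register_forward_hook(lambda module, inputs, output: None))

_ = layer(torch.randn(1, 8))  # the hook fires during this forward pass

# Remove and clear (what post_compress now does).
for handle in handles:
    handle.remove()
handles = []
```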