Activation Ordering Strategies #121

Merged · 80 commits · Sep 3, 2024

Commits
012138a
actorder
horheynm Jul 2, 2024
f88c84e
g_idx fix
horheynm Jul 10, 2024
3211fe1
fix
horheynm Jul 10, 2024
bbbf564
lint
horheynm Jul 10, 2024
8d29f0d
propagagte g_idx with perm
horheynm Jul 11, 2024
89224e9
scratch
horheynm Jul 12, 2024
cb8446d
GPTQ - move calibration of quantiztion params to after hessian calibr…
Jul 18, 2024
d7029a0
no recompute
horheynm Jul 22, 2024
eeff533
clean up
horheynm Jul 22, 2024
842b150
remvoe unwanted code
horheynm Jul 22, 2024
240c39d
draft
horheynm Jul 27, 2024
820d08a
draft
horheynm Jul 31, 2024
564845e
draft
horheynm Aug 1, 2024
6f54737
mimic gptq
horheynm Aug 6, 2024
2cc99bb
permutation seems to be working
kylesayrs Aug 9, 2024
6fe537d
WIP: fails on non-square weights
kylesayrs Aug 9, 2024
6611073
pass perm into quant params calculation
kylesayrs Aug 9, 2024
9077969
works on vllm and loading with identity permutation
kylesayrs Aug 12, 2024
6a1565e
WIP: working pytorch with actorder
kylesayrs Aug 12, 2024
1940df4
able to inference with script and reload, needed to set
kylesayrs Aug 12, 2024
11beac1
remove testing comments
kylesayrs Aug 13, 2024
9456698
remove scripts
kylesayrs Aug 13, 2024
0c773e6
remove dregs
kylesayrs Aug 13, 2024
b6bebc2
merge actorder and group cases
kylesayrs Aug 13, 2024
3bde194
code structuring and cleanup
kylesayrs Aug 13, 2024
758c495
use `refresh_layer_weight_quant_params`
kylesayrs Aug 13, 2024
85fb1ff
update_layer_weight_quant_params reuse
kylesayrs Aug 14, 2024
5b52e9d
deep copy H to allow for future reuse
kylesayrs Aug 14, 2024
9e2cef9
hoist group_size
kylesayrs Aug 16, 2024
e725cc7
remove footer note
kylesayrs Aug 16, 2024
2392b83
apply style
kylesayrs Aug 16, 2024
a5a30e1
fix rebase dreggs
kylesayrs Aug 16, 2024
ca6fc6e
remove extra line
kylesayrs Aug 16, 2024
6f99634
move lines for better grouping
kylesayrs Aug 16, 2024
b726bd6
move for better diff
kylesayrs Aug 16, 2024
2002761
remove extra lines
kylesayrs Aug 16, 2024
0ef0c5b
use getattr to avoid pr dep
kylesayrs Aug 17, 2024
476aed0
Revert "use getattr to avoid pr dep"
kylesayrs Aug 17, 2024
ffb809c
add actorder to docstring
kylesayrs Aug 21, 2024
edc02d4
Merge remote-tracking branch 'origin' into kylesayrs/activation-ordering
kylesayrs Aug 22, 2024
bc49946
do not clone hessian
kylesayrs Aug 22, 2024
99f2286
apply style
kylesayrs Aug 22, 2024
48b36c2
avoid unset g_idx parameter by observing directly
kylesayrs Aug 22, 2024
9550f14
use update_layer_weight_quant_params
kylesayrs Aug 22, 2024
d22ff2e
Merge remote-tracking branch 'origin/main' into kylesayrs/activation-…
kylesayrs Aug 23, 2024
72d919f
Merge branch 'main' into kylesayrs/activation-ordering
kylesayrs Aug 25, 2024
e4d37a6
indent for when quantization_scheme is missing
kylesayrs Aug 25, 2024
cdc8bcd
add actorder e2e test
kylesayrs Aug 25, 2024
1fe188b
do not freeze if initialized from gptq
kylesayrs Aug 27, 2024
b06a103
add get_attr_chain helper function
kylesayrs Aug 27, 2024
f293efd
cleanup and clarify logic
kylesayrs Aug 27, 2024
a99e0da
apply style
kylesayrs Aug 27, 2024
bf915d4
rename to getattr_chain, handle no default case
kylesayrs Aug 27, 2024
66ef96b
out of place type conversion
kylesayrs Aug 27, 2024
98aaf88
Merge remote-tracking branch 'origin/gptq-cleanup' into kylesayrs/act…
kylesayrs Aug 27, 2024
91c877a
account for extra case
kylesayrs Aug 27, 2024
b711e14
remove freeze_quantization argument
kylesayrs Aug 28, 2024
974dbc7
remove fake_quantization case, update debug message
kylesayrs Aug 28, 2024
094e429
remove todo
kylesayrs Aug 28, 2024
582c179
Merge remote-tracking branch 'origin/gptq-cleanup' into kylesayrs/act…
kylesayrs Aug 28, 2024
febb741
correct name
kylesayrs Aug 28, 2024
c74bbbd
wip
kylesayrs Aug 28, 2024
1af4a04
unpermute in both cases
kylesayrs Aug 28, 2024
1fdfdd7
integrate with ct
kylesayrs Aug 30, 2024
8a22e82
Merge remote-tracking branch 'origin/main' into kylesayrs/actorder_st…
kylesayrs Aug 30, 2024
2d3cc54
resolve merge
kylesayrs Aug 30, 2024
9bfeb9f
fix import
kylesayrs Aug 30, 2024
2976c7e
pass args
kylesayrs Aug 30, 2024
e8053dd
write g_idx only for group activation ordering
kylesayrs Aug 30, 2024
143ec4b
unpermute for weight option as well
kylesayrs Aug 30, 2024
bfa3ff6
rename and add tests
kylesayrs Aug 30, 2024
3de429c
comments
kylesayrs Aug 30, 2024
acd44cf
Merge remote-tracking branch 'origin/main' into kylesayrs/actorder_st…
kylesayrs Sep 1, 2024
2696f45
Activation Ordering Tests (#135)
kylesayrs Sep 3, 2024
c6710f8
update enum name, apply style
kylesayrs Sep 3, 2024
c39d945
break out actordering
kylesayrs Sep 3, 2024
d03ba61
Merge remote-tracking branch 'origin/main' into kylesayrs/actorder_st…
kylesayrs Sep 3, 2024
00f2fa0
compute sparsity mask after permutation
kylesayrs Sep 3, 2024
14559a9
use get_observer
kylesayrs Sep 3, 2024
f903305
Merge branch 'main' into kylesayrs/actorder_static_group
kylesayrs Sep 3, 2024
Files changed
121 changes: 78 additions & 43 deletions src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -1,11 +1,16 @@
import time
from typing import Tuple

from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization import (
ActivationOrdering,
QuantizationArgs,
QuantizationStrategy,
)
from compressed_tensors.quantization.lifecycle.forward import fake_quantize
from compressed_tensors.quantization.observers import MemorylessObserver

from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD
from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper
from llmcompressor.pytorch.utils.helpers import tensor_sparsity
from llmcompressor.utils import getattr_chain
from llmcompressor.utils.metric_logging import (
get_GPU_memory_usage,
@@ -89,20 +94,20 @@ def compress(
:param percdamp: Amount of dampening to apply to H, as a fraction of the
diagonal norm
"""
weight_quant_args = getattr_chain(
self.layer, "quantization_scheme.weights", None
)
args_loc = "quantization_scheme.weights"
weight_quant_args = getattr_chain(self.layer, args_loc, None)
if weight_quant_args is None:
logger.debug(f"Skipping unquantized layer {self.name}...")
return

if is_module_offloaded(self.layer):
self.layer._hf_hook.pre_forward(self.layer)

strategy = weight_quant_args.strategy
actorder = weight_quant_args.actorder
final_shape = self.layer.weight.shape
final_dtype = self.layer.weight.dtype
W = self.layer.weight.data.clone()
from llmcompressor.pytorch.utils.helpers import tensor_sparsity

# standardize shape and dtype
if isinstance(self.layer, nn.Conv2d):
@@ -111,6 +116,33 @@ def compress(
W.transpose_(0, 1)
W = W.float()

tick = time.time()

if strategy == QuantizationStrategy.GROUP:
# mapping from column index to group index
g_idx = (
torch.arange(self.columns, device=W.device, dtype=torch.int)
// weight_quant_args.group_size
)

if actorder == ActivationOrdering.GROUP:
# permute by activation order first, then update groups
W, self.H, perm = self._apply_activation_ordering(W, self.H)
self._update_quantization_parameters(weight_quant_args, W)

# use identity g_idx (invert permutation later)

elif actorder == ActivationOrdering.WEIGHT:
# update groups first, then permute by activation order
self._update_quantization_parameters(weight_quant_args, W)
W, self.H, perm = self._apply_activation_ordering(W, self.H)

# permute g_idx to maintain identity mapping after unpermutation
g_idx = g_idx[perm]

scale = self.layer.weight_scale
zero_point = self.layer.weight_zero_point

# sparsity mask
sparsity = tensor_sparsity(W)
preserve_zeros = sparsity >= SPARSITY_THRESHOLD
@@ -120,26 +152,6 @@
else None
)

tick = time.time()

# consider activation ordering
if weight_quant_args.actorder:
# use hessian to create a permutation of weights
perm = torch.argsort(torch.diag(self.H), descending=True)

# permute weight and hessian
W = W[:, perm]
self.H = self.H[perm][:, perm]

# update quantization parameters for activation ordering
observer = MemorylessObserver(weight_quant_args)
_scale, _zero_point = observer(W)
update_parameter_data(self.layer, _scale, "weight_scale")
update_parameter_data(self.layer, _zero_point, "weight_zero_point")

scale = self.layer.weight_scale
zero_point = self.layer.weight_zero_point

# mask dead hessian values
dead = torch.diag(self.H) == 0
self.H[dead, dead] = 1
@@ -176,7 +188,6 @@ def compress(
q = w.clone()

# quantize column
strategy = weight_quant_args.strategy
if strategy == QuantizationStrategy.TENSOR:
q = fake_quantize(
q,
@@ -194,16 +205,16 @@
elif strategy == QuantizationStrategy.GROUP:
# get the group index for the current column
column_idx = i1 + i
input_dim_group = column_idx // weight_quant_args.group_size
group_index = g_idx[column_idx]

# Since we're only applying quantization to a slice, this
# ends up being a channelwise application
altered_qargs = copy(weight_quant_args)
altered_qargs.strategy = QuantizationStrategy.CHANNEL
q = fake_quantize(
q,
scale[:, input_dim_group],
zero_point[:, input_dim_group],
scale[:, group_index],
zero_point[:, group_index],
altered_qargs,
)
else:
@@ -235,21 +246,22 @@
W[:, i2:] -= w_err

if "METRIC" in logger._core.levels.keys():
self.log_metrics(tick, Losses)
self._log_metrics(tick, Losses)

if weight_quant_args.actorder:
# restore original permutation
invperm = torch.argsort(perm)
W = W[:, invperm]
if strategy == QuantizationStrategy.GROUP:
if actorder == ActivationOrdering.WEIGHT:
# restore original permutation
invperm = torch.argsort(perm)
W = W[:, invperm]

# g_idx describes the group index of the permuted weight
g_idx = torch.tensor(
[i // weight_quant_args.group_size for i in range(self.columns)],
dtype=torch.int,
).to(device=invperm.device)
elif actorder == ActivationOrdering.GROUP:
# restore original permutation
invperm = torch.argsort(perm)
W = W[:, invperm]
g_idx = g_idx[invperm]

# invert to get the group index of the unpermuted weight
update_parameter_data(self.layer, g_idx[invperm], "weight_g_idx")
# only save g_idx if mapping is not identity
update_parameter_data(self.layer, g_idx, "weight_g_idx")

if isinstance(self.layer, transformers.Conv1D):
W.transpose_(0, 1)
@@ -272,7 +284,30 @@ def free(self):
delattr(self, "H")
super().free()

def log_metrics(self, start_tick: float, losses: torch.Tensor):
def _update_quantization_parameters(self, args: QuantizationArgs, W: torch.Tensor):
"""
Update layer quantization parameters with potentially permuted weight

:param args: quantization arguments
:param W: weight to calculate quantization parameters from
"""
observer = args.get_observer()
_scale, _zero_point = observer(W, g_idx=None)
update_parameter_data(self.layer, _scale, "weight_scale")
update_parameter_data(self.layer, _zero_point, "weight_zero_point")

def _apply_activation_ordering(
self, W: torch.Tensor, H: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Permute weight and hessian in order of greatest output activations

:param W: weight to permute
:param H: hessian to permute
"""
perm = torch.argsort(torch.diag(H), descending=True)
return W[:, perm], H[perm][:, perm], perm

def _log_metrics(self, start_tick: float, losses: torch.Tensor):
"""
Log metrics related to compression algorithm

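The core of the change is visible in the hunks above: for group-wise strategies, columns are permuted by descending Hessian diagonal so the most activation-sensitive columns are quantized first, group scales are computed either after the permutation (actorder: "group") or before it (actorder: "weight"), and g_idx records the column-to-group mapping so the weight can be un-permuted before saving. A minimal, self-contained sketch of that bookkeeping, with illustrative names and shapes rather than the PR's exact code:

import torch

def apply_activation_ordering(W: torch.Tensor, H: torch.Tensor):
    # sort columns so those with the largest Hessian diagonal are quantized first
    perm = torch.argsort(torch.diag(H), descending=True)
    return W[:, perm], H[perm][:, perm], perm

columns, group_size = 512, 128
W = torch.randn(256, columns)
H = torch.diag(torch.rand(columns))  # stand-in Hessian

# column index -> group index (identity mapping before any permutation)
g_idx = torch.arange(columns, dtype=torch.int) // group_size

W_perm, H_perm, perm = apply_activation_ordering(W, H)

# actorder="weight": scales are computed on the unpermuted weight, so g_idx must
# follow the permutation; actorder="group" computes scales after permuting and
# keeps the identity g_idx, inverting it together with the weight at the end
g_idx_weight = g_idx[perm]

# after quantization, undo the permutation so the stored weight matches the
# original column order
invperm = torch.argsort(perm)
W_restored = W_perm[:, invperm]
assert torch.equal(W_restored, W)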
9 changes: 9 additions & 0 deletions tests/e2e/vLLM/configs/actorder/w4a16_actorder_group.yaml
@@ -0,0 +1,9 @@
cadence: "nightly"
test_type: "regression"
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
recipe: tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_group.yaml
dataset_id: openai/gsm8k
dataset_config: main
dataset_split: train
scheme: W4A16
save_dir: TinyLlama-1.1B-Chat-v1.0-actorder-group
9 changes: 9 additions & 0 deletions tests/e2e/vLLM/configs/actorder/w4a16_actorder_weight.yaml
@@ -0,0 +1,9 @@
cadence: "nightly"
test_type: "regression"
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
recipe: tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_weight.yaml
dataset_id: openai/gsm8k
dataset_config: main
dataset_split: train
scheme: W4A16
save_dir: TinyLlama-1.1B-Chat-v1.0-actorder-weight
15 changes: 15 additions & 0 deletions tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_group.yaml
@@ -0,0 +1,15 @@
quant_stage:
quant_modifiers:
GPTQModifier:
sequential_update: false
ignore: ["lm_head"]
config_groups:
group_0:
weights:
num_bits: 4
type: "int"
symmetric: true
strategy: "group"
group_size: 128
actorder: "group"
targets: ["Linear"]
15 changes: 15 additions & 0 deletions tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_weight.yaml
@@ -0,0 +1,15 @@
quant_stage:
quant_modifiers:
GPTQModifier:
sequential_update: false
ignore: ["lm_head"]
config_groups:
group_0:
weights:
num_bits: 4
type: "int"
symmetric: true
strategy: "group"
group_size: 128
actorder: "weight"
targets: ["Linear"]
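The two recipes above are exercised by the e2e test below via the config files. Outside the harness, they would be applied with the oneshot entrypoint that the test imports. A rough usage sketch — the gsm8k chat-template preprocessing and the exact oneshot keyword names are assumptions based on the test code, not a documented example:

from datasets import load_dataset
from transformers import AutoTokenizer

from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
RECIPE = "tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_group.yaml"
MAX_SEQ_LENGTH = 2048
NUM_CALIBRATION_SAMPLES = 256

model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# small calibration set built from the dataset named in the e2e configs
ds = load_dataset("openai/gsm8k", name="main", split="train")
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
ds = ds.map(
    lambda sample: tokenizer(
        tokenizer.apply_chat_template(
            [{"role": "user", "content": sample["question"]}], tokenize=False
        ),
        max_length=MAX_SEQ_LENGTH,
        truncation=True,
        add_special_tokens=False,
    ),
    remove_columns=ds.column_names,
)

oneshot(
    model=model,
    dataset=ds,
    recipe=RECIPE,
    max_seq_length=MAX_SEQ_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    output_dir="TinyLlama-1.1B-Chat-v1.0-actorder-group",
)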
49 changes: 19 additions & 30 deletions tests/e2e/vLLM/test_vllm.py
@@ -8,7 +12,8 @@

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from tests.testing_utils import parse_params, requires_gpu, requires_torch
from tests.testing_utils import (
parse_params,
preprocess_tokenize_dataset,
requires_gpu,
requires_torch,
)

try:
from vllm import LLM, SamplingParams
@@ -17,12 +22,8 @@
except ImportError:
vllm_installed = False

# Defines the file paths to the directories containing the test configs
# for each of the quantization schemes
WNA16 = "tests/e2e/vLLM/configs/WNA16"
FP8 = "tests/e2e/vLLM/configs/FP8"
INT8 = "tests/e2e/vLLM/configs/INT8"
CONFIGS = [WNA16, FP8, INT8]

CONFIGS = "tests/e2e/vLLM/configs"


@requires_gpu
@@ -50,14 +51,15 @@ class TestvLLM(unittest.TestCase):
model = None
scheme = None
dataset_id = None
dataset_config = None
dataset_split = None
recipe = None
save_dir = None

def setUp(self):
print("========== RUNNING ==============")
print(self.scheme)

self.save_dir = None
self.device = "cuda:0"
self.oneshot_kwargs = {}
self.num_calibration_samples = 256
@@ -75,35 +77,21 @@ def test_vllm(self):
)
tokenizer = AutoTokenizer.from_pretrained(self.model)

def preprocess(example):
return {
"text": tokenizer.apply_chat_template(
example["messages"],
tokenize=False,
)
}

def tokenize(sample):
return tokenizer(
sample["text"],
padding=False,
max_length=self.max_seq_length,
truncation=True,
add_special_tokens=False,
)

if self.dataset_id:
ds = load_dataset(self.dataset_id, split=self.dataset_split)
ds = load_dataset(
self.dataset_id, name=self.dataset_config, split=self.dataset_split
)
ds = ds.shuffle(seed=42).select(range(self.num_calibration_samples))
ds = ds.map(preprocess)
ds = ds.map(tokenize, remove_columns=ds.column_names)
ds = preprocess_tokenize_dataset(ds, tokenizer, self.max_seq_length)
self.oneshot_kwargs["dataset"] = ds
self.oneshot_kwargs["max_seq_length"] = self.max_seq_length
self.oneshot_kwargs["num_calibration_samples"] = (
self.num_calibration_samples
)

self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
if self.save_dir is None:
self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"

self.oneshot_kwargs["model"] = loaded_model
if self.recipe:
self.oneshot_kwargs["recipe"] = self.recipe
@@ -141,4 +129,5 @@ def tokenize(sample):
tokenizer.push_to_hub(f"nm-testing/{self.save_dir}-e2e")

def tearDown(self):
shutil.rmtree(self.save_dir)
if self.save_dir is not None:
shutil.rmtree(self.save_dir)
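The inline preprocess/tokenize closures deleted above are replaced by the preprocess_tokenize_dataset helper imported from tests.testing_utils. A sketch of what such a helper could look like, reconstructed from the deleted code; the real helper is not shown in this diff and presumably also handles datasets without a "messages" column, such as gsm8k:

from datasets import Dataset
from transformers import PreTrainedTokenizerBase


def preprocess_tokenize_dataset(
    ds: Dataset, tokenizer: PreTrainedTokenizerBase, max_seq_length: int
) -> Dataset:
    # render each chat transcript to plain text with the model's chat template
    def preprocess(example):
        return {
            "text": tokenizer.apply_chat_template(
                example["messages"],
                tokenize=False,
            )
        }

    # tokenize the rendered text, truncating to the calibration sequence length
    def tokenize(sample):
        return tokenizer(
            sample["text"],
            padding=False,
            max_length=max_seq_length,
            truncation=True,
            add_special_tokens=False,
        )

    ds = ds.map(preprocess)
    ds = ds.map(tokenize, remove_columns=ds.column_names)
    return ds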
@@ -1,5 +1,5 @@
cadence: "nightly"
test_type: "regression"
model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_actorder.yaml"
new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_actorder_group.yaml"
ppl_threshold: 20
@@ -0,0 +1,5 @@
cadence: "nightly"
test_type: "regression"
model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_actorder_weight.yaml"
ppl_threshold: 20
@@ -10,7 +10,7 @@ test_stage:
symmetric: False
strategy: "group"
group_size: 128
actorder: True
actorder: "group"
input_activations: null
output_activations: null
targets: ["Linear"]
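Note the recipe change above: actorder switches from the boolean True to the string "group". Those strings are validated against the ActivationOrdering enum imported in gptq_wrapper.py. A quick sanity check, assuming the enum's values match the recipe strings:

from compressed_tensors.quantization import ActivationOrdering

# assumed string-backed enum values; verify against compressed-tensors itself
assert ActivationOrdering("group") == ActivationOrdering.GROUP
assert ActivationOrdering("weight") == ActivationOrdering.WEIGHT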
@@ -0,0 +1,19 @@
test_stage:
quant_modifiers:
QuantizationModifier:
ignore: ["lm_head", "model.layers.0.mlp.down_proj"]
config_groups:
group_0:
weights:
num_bits: 4
type: "int"
symmetric: False
strategy: "group"
group_size: 128
actorder: "weight"
input_activations: null
output_activations: null
targets: ["Linear"]
GPTQModifier:
block_size: 128
sequential_update: False