Skip to content

Commit

Permalink
integrate with wanda
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Nov 15, 2024
1 parent 0bc7bae commit 261dadd
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 40 deletions.
76 changes: 38 additions & 38 deletions src/llmcompressor/modifiers/pruning/wanda/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
Expand All @@ -9,6 +10,7 @@
from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.pruning.wanda.utils.wanda_wrapper import WandaWrapper
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.utils.pytorch.module import (
Expand All @@ -20,7 +22,7 @@
__all__ = ["WandaPruningModifier"]


class WandaPruningModifier(Modifier):
class WandaPruningModifier(Modifier, HooksMixin):
"""
Modifier for applying the one-shot WANDA algorithm to a model
from the paper: https://arxiv.org/abs/2306.11695
Expand Down Expand Up @@ -121,7 +123,8 @@ def initialize_compression(
"Inferring layer-wise sparsities from "
f"{len(dataloader) if dataloader else 0} calibration samples..."
)
self.sparsity = self._infer_layer_sparsity(dataloader)
activations = self._get_activations(dataloader)
self.sparsity = self._infer_layer_sparsity(activations)
self._validate_layerwise_sparsity()

for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
Expand Down Expand Up @@ -224,19 +227,17 @@ def _infer_mask_block_size(self):

self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))

def _infer_layer_sparsity(self, calibration_dataloader):
acts = _get_activations(self.model, calibration_dataloader)
def _infer_layer_sparsity(self, activations):
wanda = {}
for name, layer in self.compressible_layers_.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0)
m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
for n, m in prunable_layers.items()
]
wanda[name] = torch.cat([item.flatten().cpu() for item in z])

acts = None
del acts
del activations
torch.cuda.empty_cache()

outlier_ratios = {}
Expand Down Expand Up @@ -268,36 +269,35 @@ def _infer_layer_sparsity(self, calibration_dataloader):
logger.info(f"Sparsity for {k}: {sparsities[k]}")
return sparsities

@torch.no_grad()
def _get_activations(self, data_loader, nsamples=128):
self.model.eval()
acts = {}

def save_acts(module, input, name):
if isinstance(input, tuple):
input = input[0]
if name not in acts:
acts[name] = (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
)
else:
acts[name] += (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
)

for name, mod in self.model.named_modules():
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
self.register_hook(
mod, functools.partial(save_acts, name=name), "forward_pre"
)
device = next(self.model.parameters()).device
for batch in tqdm(data_loader):
batch = {k: v.to(device) for k, v in batch.items()}
self.model(**batch)
batch = None
torch.cuda.empty_cache()

@torch.no_grad()
def _get_activations(model, data_loader, nsamples=128):
import functools

model.eval()
acts = {}

def save_acts(module, input, name):
if isinstance(input, tuple):
input = input[0]
if name not in acts:
acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
else:
acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()

hooks = []
for name, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
hooks.append(
mod.register_forward_pre_hook(functools.partial(save_acts, name=name))
)
device = next(model.parameters()).device
for batch in tqdm(data_loader):
batch = {k: v.to(device) for k, v in batch.items()}
model(**batch)
batch = None
torch.cuda.empty_cache()

for h in hooks:
h.remove()
self.remove_hooks()

return acts
return acts
6 changes: 4 additions & 2 deletions src/llmcompressor/modifiers/utils/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def register_hook(
hook: Callable[[Any], Any],
hook_type: str,
**kwargs,
):
) -> RemovableHandle:
"""
Registers a hook on a specified module with the option to disable it with
HooksMixin.disable_hooks
Expand All @@ -68,7 +68,9 @@ def wrapped_hook(*args, **kwargs):

handle = getattr(module, f"register_{hook_type}_hook")(wrapped_hook, **kwargs)
self._hooks.append(handle)
logger.debug(f"Added {handle} for {self}")
logger.debug(f"{self} added {handle}")

return handle

def remove_hooks(self):
"""Remove all hooks belonging to a modifier"""
Expand Down

0 comments on commit 261dadd

Please sign in to comment.