enhance sparsegpt (#505)
* update

* fix bug

* fix bug

* update opt

* add memory-efficient forward for OPT

* support setting the device for pruning

---------

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Your Name <you@example.com>
3 people committed Apr 12, 2023
1 parent 316977b commit bb4e3a9
Showing 6 changed files with 153 additions and 79 deletions.
6 changes: 5 additions & 1 deletion mmrazor/implementations/pruning/sparse_gpt/mutator.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from mmrazor.utils import print_log
@@ -19,11 +20,14 @@ def end_init_hessian(self):
for module in self.sparse_ops:
module.end_init_hessian()

def prune_24(self):
def prune_24(self, device=torch.device('cuda:0')):
for name, module in self.named_sparse_ops:
try:
original_device = next(module.parameters()).device
module = module.to(device)
error = module.prune_24()
print_log(f'prune {name} success \t error = {error}')
module.to(original_device)
except Exception as e:
print_log(f'prune {name} failed as {e}')

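A minimal sketch of how the new device argument of prune_24 can be used: the model itself stays on the CPU and each sparse op is moved to cuda:0 only for its own pruning step, then returned to its original device. The resnet18 model and the elided calibration step are illustrative, not part of this diff.

import torch
import torchvision

from mmrazor.implementations.pruning import sparse_gpt

model = torchvision.models.resnet18()  # model weights stay on the CPU
mutator = sparse_gpt.SparseGptMutator.init_from_a_model(model)

# ... run calibration forwards between mutator.start_init_hessian() and
# mutator.end_init_hessian() so every op has Hessian statistics ...

mutator.prune_24(device=torch.device('cuda:0'))  # each op visits the GPU in turn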
12 changes: 11 additions & 1 deletion mmrazor/implementations/pruning/sparse_gpt/ops.py
@@ -190,9 +190,14 @@ def __init__(self, *args, **kwargs) -> None:
self._sparse_gpt_mix_in_init()

@classmethod
def convert_from(cls, module: nn.Linear):
def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
new_module = super().convert_from(module)
new_module.load_state_dict(module.state_dict(), strict=False)

device = next(module.parameters()).device
dtype = next(module.parameters()).dtype
new_module = new_module.to(device).to(dtype)

return new_module


@@ -206,6 +211,11 @@ def __init__(self, *args, **kwargs) -> None:
def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
new_module = super().convert_from(module)
new_module.load_state_dict(module.state_dict(), strict=False)

device = next(module.parameters()).device
dtype = next(module.parameters()).dtype
new_module = new_module.to(device).to(dtype)

return new_module

def format_input(self, input: torch.Tensor):
55 changes: 55 additions & 0 deletions mmrazor/implementations/pruning/sparse_gpt/utils.py
@@ -6,6 +6,7 @@

from mmrazor.models.architectures.dynamic_ops import DynamicMixin
from mmrazor.models.utils import get_module_device
from mmrazor.utils import print_log


class ModuleProtocol(Protocol):
@@ -44,3 +45,57 @@ def replace_op(model: nn.Module, name: str, module: nn.Module):
new_module = dynamicop_map[type(module)].convert_from(module).to(
get_module_device(module))
replace_op(model, name, new_module)


def register_efficient_forward_hook(module: nn.Module,
device=torch.device('cuda:0')):

def forward_pre_hook(module: nn.Module, input):
module.to(device)

def forward_hook(module: nn.Module, input, output):
module.to('cpu')
torch.cuda.empty_cache()

h1 = module.register_forward_pre_hook(forward_pre_hook)
h2 = module.register_forward_hook(forward_hook)
return [h1, h2]


def enable_efficient_forward(model: nn.Module,
device=torch.device('cuda:0'),
wrap_modules=[]):
handles = []
blocks = []
for name, module in model.named_children():
if type(module) in wrap_modules or len(module._parameters) != 0 or len(
module._buffers) != 0:
handles_ = register_efficient_forward_hook(module, device)
blocks_ = [name]
else:
handles_, blocks_ = enable_efficient_forward(
module, device, wrap_modules)
handles += handles_
blocks += blocks_
return handles, blocks


class memory_efficient_forward:

def __init__(self,
model: nn.Module,
device=torch.device('cuda:0'),
wrap_modules=[]) -> None:
self.model = model
self.device = device
self.wrap_modules = wrap_modules

def __enter__(self, ):
handles, blocks = enable_efficient_forward(self.model, self.device,
self.wrap_modules)
print_log(f'enable memory efficient forward for {blocks}')
self.handlers = handles

def __exit__(self, exc_type, exc_value, exc_traceback):
for h in self.handlers:
h.remove()
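A minimal usage sketch of the new memory_efficient_forward context manager, assuming a CUDA device is available. The toy nn.Sequential model and random input are illustrative; the real OPT usage with wrap_modules=[OPTDecoderLayer] appears in opt_sparse.py below.

import torch
import torch.nn as nn

from mmrazor.implementations.pruning.sparse_gpt.utils import \
    memory_efficient_forward

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))  # stays on the CPU
x = torch.randn(2, 16, device='cuda:0')  # inputs live on the GPU

with memory_efficient_forward(model, device=torch.device('cuda:0')):
    # each parameterized child is moved to cuda:0 just before its forward
    # pass and back to the CPU right after it, so at most one block holds
    # GPU memory at a time
    out = model(x)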
87 changes: 52 additions & 35 deletions projects/sparse_gpt/llm/opt_sparse.py
@@ -2,6 +2,10 @@
import torch
import torch.nn as nn
from transformers import OPTForCausalLM
from transformers.models.opt.modeling_opt import OPTDecoderLayer

from mmrazor.implementations.pruning.sparse_gpt.utils import \
memory_efficient_forward

has_wandb = False

@@ -24,35 +28,47 @@ def skip(*args, **kwargs):
return model


def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
# tokens: 1 N
N = tokens.shape[1]
num_drop = N % batch_seq_len
if num_drop != 0:
tokens = tokens[:, :-num_drop]
tokens = tokens.reshape([-1, batch_seq_len]) # B N
return tokens


@torch.no_grad()
def opt_eval(model: OPTForCausalLM,
testenc,
dev,
dataset: str,
dev=torch.device('cuda:0'),
batch_size=64,
log_wandb: bool = False):
print('Evaluating ...')

testenc: torch.Tensor = testenc.input_ids # type: ignore
nsamples = testenc.numel() // model.seqlen
seqlen = model.seqlen

testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N
testenc = fold_tokens(testenc, seqlen) # B N

use_cache = model.config.use_cache
model.config.use_cache = False
nlls = []

for i in range(nsamples):
batch = testenc[:, (i * model.seqlen):(i + 1) * model.seqlen].to(dev)
out = model(batch)[0] # 1
for batch in torch.split(testenc, batch_size):
B = batch.shape[0]

batch = batch.to(dev)
out: torch.Tensor = model(batch)[0] # 1

shift_logits = out[:, :-1, :].contiguous() # 1 N C
shift_labels = batch[:, 1:] # 1 N
shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C
shift_labels = batch[:, 1:].flatten() # (B N)

loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
neg_log_likelihood = loss.float() * model.seqlen
loss = loss_fct(shift_logits, shift_labels)
neg_log_likelihood = loss.float() * seqlen * B
nlls.append(neg_log_likelihood)
ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
ppl = torch.exp(torch.stack(nlls).sum() / (testenc.numel()))
print(f'Perplexity: {ppl.item():3f}')
model.config.use_cache = use_cache

Expand All @@ -62,22 +78,21 @@ def opt_infer(
model: OPTForCausalLM,
testenc,
dev,
num_samples=128,
batch_size=64,
):
print('Infer ...')

testenc: torch.Tensor = testenc.input_ids # type: ignore
nsamples = testenc.numel() // model.seqlen
seqlen = model.seqlen

testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N
testenc = fold_tokens(testenc, seqlen) # B N

model.config.use_cache = False

for i in range(nsamples):
batch = testenc[:, (i * model.seqlen):(i + 1) * model.seqlen].to(dev)
for batch in torch.split(testenc, batch_size):
batch = batch.to(dev)
_ = model(batch)[0] # 1

if i > num_samples:
break


if __name__ == '__main__':
import argparse
@@ -107,23 +122,25 @@ def opt_infer(

model = get_opt(args.model)
model.eval()
model = model.cuda()
print('load model over')
DEV = torch.device('cuda:0')

dataloader, testloader = get_loaders(
'c4', seed=args.seed, model=args.model, seqlen=model.seqlen)

from mmrazor.implementations.pruning import sparse_gpt
mutator = sparse_gpt.SparseGptMutator.init_from_a_model(model)

mutator.start_init_hessian()
opt_infer(model, testloader, DEV, num_samples=128)
mutator.end_init_hessian()
mutator.prune_24()

for dataset in ['wikitext2', 'ptb', 'c4']:
dataloader, testloader = get_loaders(
dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print(dataset)
opt_eval(model, testloader, DEV, dataset)
mutator = sparse_gpt.SparseGptMutator.init_from_a_model(
model.model.decoder)

with memory_efficient_forward(model, wrap_modules=[OPTDecoderLayer]):

mutator.start_init_hessian()
opt_infer(model, testloader, DEV)
mutator.end_init_hessian()
mutator.prune_24()

for dataset in ['wikitext2', 'ptb', 'c4']:
dataloader, testloader = get_loaders(
dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print(dataset)
opt_eval(model, testloader, DEV)
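A tiny worked example of the fold_tokens helper introduced above, restated inline so the snippet runs on its own; the toy tensor and batch_seq_len value are illustrative.

import torch

def fold_tokens(tokens: torch.Tensor, batch_seq_len: int = 2048) -> torch.Tensor:
    # drop the trailing remainder so the stream divides evenly, then reshape
    # the single (1, N) token row into (B, batch_seq_len) rows
    num_drop = tokens.shape[1] % batch_seq_len
    if num_drop != 0:
        tokens = tokens[:, :-num_drop]
    return tokens.reshape(-1, batch_seq_len)

tokens = torch.arange(10).reshape(1, 10)
print(fold_tokens(tokens, batch_seq_len=4))
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])  -- the last 2 tokens are dropped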
32 changes: 0 additions & 32 deletions projects/sparse_gpt/pipe.py

This file was deleted.

40 changes: 30 additions & 10 deletions projects/sparse_gpt/torch_sample.py
@@ -6,9 +6,10 @@
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from pipe import sparse_model
from torch.utils.data import DataLoader

from mmrazor.implementations.pruning import sparse_gpt


def get_dataloaders(batch_size, n_workers, path=''):
normalize = transforms.Normalize(
@@ -50,13 +51,14 @@ def get_dataloaders(batch_size, n_workers, path=''):
return dataloader_train, dataloader_test


def eval(model: nn.Module, dataloader_test: DataLoader):
@torch.no_grad()
def eval(model: nn.Module,
dataloader_test: DataLoader,
device=torch.device('cuda:0')):

total = 0
correct = 0

device = next(model.parameters()).device

model.eval()
with torch.no_grad():
for x, y in dataloader_test:
@@ -72,15 +74,33 @@ def eval(model: nn.Module, dataloader_test: DataLoader):
return acc


# sparse_model(model, train_loader, 512)
@torch.no_grad()
def infer(model: nn.Module,
dataloader: torch.utils.data.DataLoader,
num_batchs=256,
device=torch.device('cuda:0')):
model.eval()
with torch.no_grad():
accumulate_batch = 0
for x, _ in dataloader:
x = x.to(device)
model(x)
B = x.shape[0]
accumulate_batch += B
if accumulate_batch > num_batchs:
break


if __name__ == '__main__':
# sparse_model(model, train_loader, 512)
model = torchvision.models.resnet18(pretrained=True)
train_loader, test_loader = get_dataloaders(128, 4, 'data/imagenet_torch')

model = model.cuda()
model = sparse_model(model, test_loader, num_batchs=512)
model = torchvision.models.resnet18(pretrained=True).cuda()
train_loader, test_loader = get_dataloaders(256, 4, 'data/imagenet_torch')

mutator = sparse_gpt.SparseGptMutator.init_from_a_model(model)
mutator.start_init_hessian()
infer(model, test_loader, num_batchs=512)
mutator.end_init_hessian()
mutator.prune_24()

print('start evaluation')
model = model.cuda()
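A hedged helper, not part of this commit, for sanity-checking the result of mutator.prune_24(): after 2:4 pruning, every group of 4 consecutive input weights should contain at most 2 non-zero entries. It assumes the flattened input dimension of the layer is a multiple of 4.

import torch
import torch.nn as nn

def follows_24_pattern(layer: nn.Module) -> bool:
    # flatten to (out_features, in_features), split the input dimension into
    # groups of 4, and count non-zeros per group
    w = layer.weight.detach().reshape(layer.weight.shape[0], -1)
    groups = w.reshape(w.shape[0], -1, 4)
    return bool(((groups != 0).sum(dim=-1) <= 2).all())

# e.g. check every prunable layer:
# all(follows_24_pattern(m) for m in model.modules()
#     if isinstance(m, (nn.Linear, nn.Conv2d)))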
