From 2ac25b6fb64487b36767c0cc5f9a6b148efac69d Mon Sep 17 00:00:00 2001
From: quzha
Date: Wed, 11 Mar 2020 10:14:02 +0800
Subject: [PATCH 1/9] update modelspeedup

---
 .../compression/speedup/torch/infer_shape.py  | 16 +++---
 .../compression/torch/apply_compression.py    | 56 ++-----------------
 2 files changed, 14 insertions(+), 58 deletions(-)

diff --git a/src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py b/src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
index 701d1f58e6..84a8bc5a5d 100644
--- a/src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
+++ b/src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
@@ -324,8 +324,8 @@ def batchnorm2d_mask(module_masks, mask):
     CoarseMask, CoarseMask
         The mask of its input tensor, the mask of its output tensor
     """
-    assert 'weight' in mask and 'bias' in mask
-    sum_mask = mask['weight'] + mask['bias']
+    assert 'weight_mask' in mask and 'bias_mask' in mask
+    sum_mask = mask['weight_mask'] + mask['bias_mask']
     nonzero_index = torch.nonzero(sum_mask, as_tuple=True)[0]
     # infer shape of parameters
     param_cmask = CoarseMask(num_dim=1)
@@ -335,7 +335,7 @@ def batchnorm2d_mask(module_masks, mask):
     # infer shape of input tensor
     input_cmask = CoarseMask(num_dim=4)
     input_cmask.add_index_mask(dim=1,
-                               index=torch.nonzero(mask['weight'], as_tuple=True)[0])
+                               index=torch.nonzero(mask['weight_mask'], as_tuple=True)[0])
     module_masks.set_input_mask(input_cmask)
     # infer shape of output tensor
     output_cmask = CoarseMask(num_dim=4)
@@ -371,9 +371,9 @@ def convert_to_coarse_mask(mask):
     LongTensor, CoarseMask, CoarseMask
         Index of the masked dimension, weight mask, bias mask
     """
-    assert 'weight' in mask
-    assert isinstance(mask['weight'], torch.Tensor)
-    weight_mask = mask['weight']
+    assert 'weight_mask' in mask
+    assert isinstance(mask['weight_mask'], torch.Tensor)
+    weight_mask = mask['weight_mask']
     shape = weight_mask.size()
     ones = torch.ones(shape[1:]).to(weight_mask.device)
     zeros = torch.zeros(shape[1:]).to(weight_mask.device)
@@ -393,8 +393,8 @@ def convert_to_coarse_mask(mask):
     weight_cmask = CoarseMask(num_dim=4)
     weight_cmask.add_index_mask(dim=0, index=index)
     bias_cmask = None
-    if 'bias' in mask and mask['bias'] is not None:
-        bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
+    if 'bias_mask' in mask and mask['bias_mask'] is not None:
+        bias_index = torch.nonzero(mask['bias_mask'], as_tuple=True)[0]
         assert torch.all(torch.eq(index, bias_index)), \
             "bias mask should be consistent with weight mask"
         bias_cmask = CoarseMask(num_dim=1)
diff --git a/src/sdk/pynni/nni/compression/torch/apply_compression.py b/src/sdk/pynni/nni/compression/torch/apply_compression.py
index 2531da5039..7305aa9a99 100644
--- a/src/sdk/pynni/nni/compression/torch/apply_compression.py
+++ b/src/sdk/pynni/nni/compression/torch/apply_compression.py
@@ -18,53 +18,9 @@ def apply_compression_results(model, masks_file):
     masks_file : str
         The path of the mask file
     """
-    apply_comp = ApplyCompression(model, masks_file)
-    apply_comp.compress()
-
-class ApplyCompression(Pruner):
-    """
-    This class is not to generate masks, but applying existing masks
-    """
-
-    def __init__(self, model, masks_file):
-        """
-        Parameters
-        ----------
-        model : torch.nn.module
-            Model to be masked
-        masks_file : str
-            The path of user provided mask file
-        """
-        self.bound_model = model
-        self.masks = torch.load(masks_file)
-        for module_name in self.masks:
-            print('module_name: ', module_name)
-        config_list = self._build_config()
-        super().__init__(model, config_list)
-
-    def _build_config(self):
-        op_names = []
-        for module_name in self.masks:
-            op_names.append(module_name)
-        return [{'sparsity': 1, 'op_types': ['default', 'BatchNorm2d'], 'op_names': op_names}]
-
-    def calc_mask(self, layer, config, **kwargs):
-        """
-        Directly return the corresponding mask
-
-        Parameters
-        ----------
-        layer : LayerInfo
-            The layer to be pruned
-        config : dict
-            Pruning configurations for this weight
-        kwargs : dict
-            Auxiliary information
-
-        Returns
-        -------
-        dict
-            Mask of the layer
-        """
-        assert layer.name in self.masks
-        return self.masks[layer.name]
+    masks = torch.load(masks_file)
+    for name, module in model.named_modules():
+        if name in masks:
+            module.weight.data = module.weight.data.mul_(masks[name]['weight_mask'])
+            if 'bias_mask' in masks[name]:
+                module.bias.data = module.bias.data.mul_(masks[name]['bias_mask'])
\ No newline at end of file

From 34cb167316ffd10baf65db99109da4d942771b65 Mon Sep 17 00:00:00 2001
From: quzha
Date: Sat, 14 Mar 2020 21:17:45 +0800
Subject: [PATCH 2/9] update

---
 examples/model_compress/model_speedup.py     | 27 +++++++++----------
 .../compression/speedup/torch/infer_shape.py | 16 +++++------
 .../compression/torch/apply_compression.py   |  7 +++--
 3 files changed, 24 insertions(+), 26 deletions(-)

diff --git a/examples/model_compress/model_speedup.py b/examples/model_compress/model_speedup.py
index 9d27d98da9..5bbd3563f9 100644
--- a/examples/model_compress/model_speedup.py
+++ b/examples/model_compress/model_speedup.py
@@ -9,7 +9,7 @@ from nni.compression.torch import apply_compression_results
 
 
 torch.manual_seed(0)
-use_mask = False
+use_mask = True
 
 def apoz_speedup(masks_file, model_checkpoint):
     device = torch.device('cuda')
@@ -24,7 +24,7 @@ def apoz_speedup(masks_file, model_checkpoint):
         start = time.time()
         for _ in range(32):
             out = model(dummy_input)
-            #print(out.size(), out)
+            print(out.size(), out)
         print('mask elapsed time: ', time.time() - start)
         return
     else:
@@ -36,7 +36,7 @@ def apoz_speedup(masks_file, model_checkpoint):
         start = time.time()
         for _ in range(32):
             out = model(dummy_input)
-            #print(out.size(), out)
+            print(out.size(), out)
         print('speedup elapsed time: ', time.time() - start)
         return
 
@@ -53,28 +53,27 @@ def l1filter_speedup(masks_file, model_checkpoint):
         start = time.time()
         for _ in range(32):
             out = model(dummy_input)
-            #print(out.size(), out)
+            print(out.size(), out)
         print('mask elapsed time: ', time.time() - start)
         return
     else:
         #print("model before: ", model)
         m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
         m_speedup.speedup_model()
-        #print("model after: ", model)
+        print("model after: ", model)
         dummy_input = dummy_input.to(device)
         start = time.time()
         for _ in range(32):
             out = model(dummy_input)
-            #print(out.size(), out)
+            print(out.size(), out)
         print('speedup elapsed time: ', time.time() - start)
         return
 
 def fpgm_speedup(masks_file, model_checkpoint):
-    from fpgm_torch_mnist import Mnist
+    from model_prune_torch import NaiveModel
     device = torch.device('cpu')
-    model = Mnist()
+    model = NaiveModel()
     model.to(device)
-    model.print_conv_filter_sparsity()
 
     dummy_input = torch.randn(64, 1, 28, 28)
     if use_mask:
@@ -83,8 +82,8 @@ def fpgm_speedup(masks_file, model_checkpoint):
         start = time.time()
         for _ in range(40):
             out = model(dummy_input)
+            print(out.size(), out)
         print('mask elapsed time: ', time.time() - start)
-        #print(out.size(), out)
         return
     else:
         m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
@@ -93,8 +92,8 @@ def fpgm_speedup(masks_file, model_checkpoint):
         start = time.time()
         for _ in range(40):
             out = model(dummy_input)
+            print(out.size(), out)
         print('speedup elapsed time: ', time.time() - start)
-        #print(out.size(), out)
         return
 
 def slim_speedup(masks_file, model_checkpoint):
@@ -110,7 +109,7 @@ def slim_speedup(masks_file, model_checkpoint):
         start = time.time()
         for _ in range(32):
             out = model(dummy_input)
-            #print(out.size(), out)
+            print(out.size(), out)
         print('mask elapsed time: ', time.time() - start)
         return
     else:
@@ -122,13 +121,13 @@ def slim_speedup(masks_file, model_checkpoint):
         start = time.time()
         for _ in range(32):
             out = model(dummy_input)
-            #print(out.size(), out)
+            print(out.size(), out)
         print('speedup elapsed time: ', time.time() - start)
         return
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser("speedup")
-    parser.add_argument("--example_name", type=str, default="slim", help="the name of pruning example")
+    parser.add_argument("--example_name", type=str, default="fpgm", help="the name of pruning example")
     parser.add_argument("--masks_file", type=str, default=None, help="the path of the masks file")
     parser.add_argument("--model_checkpoint", type=str, default=None, help="the path of checkpointed model")
     args = parser.parse_args()
diff --git a/src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py b/src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
index 84a8bc5a5d..701d1f58e6 100644
--- a/src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
+++ b/src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
@@ -324,8 +324,8 @@ def batchnorm2d_mask(module_masks, mask):
     CoarseMask, CoarseMask
         The mask of its input tensor, the mask of its output tensor
     """
-    assert 'weight_mask' in mask and 'bias_mask' in mask
-    sum_mask = mask['weight_mask'] + mask['bias_mask']
+    assert 'weight' in mask and 'bias' in mask
+    sum_mask = mask['weight'] + mask['bias']
     nonzero_index = torch.nonzero(sum_mask, as_tuple=True)[0]
     # infer shape of parameters
     param_cmask = CoarseMask(num_dim=1)
@@ -335,7 +335,7 @@ def batchnorm2d_mask(module_masks, mask):
     # infer shape of input tensor
     input_cmask = CoarseMask(num_dim=4)
     input_cmask.add_index_mask(dim=1,
-                               index=torch.nonzero(mask['weight_mask'], as_tuple=True)[0])
+                               index=torch.nonzero(mask['weight'], as_tuple=True)[0])
     module_masks.set_input_mask(input_cmask)
     # infer shape of output tensor
     output_cmask = CoarseMask(num_dim=4)
@@ -371,9 +371,9 @@ def convert_to_coarse_mask(mask):
     LongTensor, CoarseMask, CoarseMask
         Index of the masked dimension, weight mask, bias mask
     """
-    assert 'weight_mask' in mask
-    assert isinstance(mask['weight_mask'], torch.Tensor)
-    weight_mask = mask['weight_mask']
+    assert 'weight' in mask
+    assert isinstance(mask['weight'], torch.Tensor)
+    weight_mask = mask['weight']
     shape = weight_mask.size()
     ones = torch.ones(shape[1:]).to(weight_mask.device)
     zeros = torch.zeros(shape[1:]).to(weight_mask.device)
@@ -393,8 +393,8 @@ def convert_to_coarse_mask(mask):
     weight_cmask = CoarseMask(num_dim=4)
     weight_cmask.add_index_mask(dim=0, index=index)
     bias_cmask = None
-    if 'bias_mask' in mask and mask['bias_mask'] is not None:
-        bias_index = torch.nonzero(mask['bias_mask'], as_tuple=True)[0]
+    if 'bias' in mask and mask['bias'] is not None:
+        bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
         assert torch.all(torch.eq(index, bias_index)), \
             "bias mask should be consistent with weight mask"
         bias_cmask = CoarseMask(num_dim=1)
diff --git a/src/sdk/pynni/nni/compression/torch/apply_compression.py b/src/sdk/pynni/nni/compression/torch/apply_compression.py
index 7305aa9a99..0f7949bd73 100644
--- a/src/sdk/pynni/nni/compression/torch/apply_compression.py
+++ b/src/sdk/pynni/nni/compression/torch/apply_compression.py
@@ -3,7 +3,6 @@
 
 import logging
 import torch
-from .compressor import Pruner
 
 logger = logging.getLogger('torch apply compression')
 
@@ -21,6 +20,6 @@ def apply_compression_results(model, masks_file):
     masks = torch.load(masks_file)
     for name, module in model.named_modules():
         if name in masks:
-            module.weight.data = module.weight.data.mul_(masks[name]['weight_mask'])
-            if 'bias_mask' in masks[name]:
-                module.bias.data = module.bias.data.mul_(masks[name]['bias_mask'])
\ No newline at end of file
+            module.weight.data = module.weight.data.mul_(masks[name]['weight'])
+            if 'bias' in masks[name]:
+                module.bias.data = module.bias.data.mul_(masks[name]['bias'])
\ No newline at end of file

From 9271217244680677e9e139653453aa2eafa8ed04 Mon Sep 17 00:00:00 2001
From: quzha
Date: Tue, 17 Mar 2020 09:08:38 +0800
Subject: [PATCH 3/9] update

---
 examples/model_compress/model_speedup.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/model_compress/model_speedup.py b/examples/model_compress/model_speedup.py
index 5bbd3563f9..9dc7e997d5 100644
--- a/examples/model_compress/model_speedup.py
+++ b/examples/model_compress/model_speedup.py
@@ -127,7 +127,7 @@ def slim_speedup(masks_file, model_checkpoint):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser("speedup")
-    parser.add_argument("--example_name", type=str, default="fpgm", help="the name of pruning example")
+    parser.add_argument("--example_name", type=str, default="slim", help="the name of pruning example")
     parser.add_argument("--masks_file", type=str, default=None, help="the path of the masks file")
     parser.add_argument("--model_checkpoint", type=str, default=None, help="the path of checkpointed model")
     args = parser.parse_args()
@@ -138,7 +138,7 @@ def slim_speedup(masks_file, model_checkpoint):
         slim_speedup(args.masks_file, args.model_checkpoint)
     elif args.example_name == 'fpgm':
         if args.masks_file is None:
-            args.masks_file = 'mask.pth'
+            args.masks_file = './checkpoints/mask_naive_mnist_fpgm.pth'
         fpgm_speedup(args.masks_file, args.model_checkpoint)
     elif args.example_name == 'l1filter':
         if args.masks_file is None:

From 6ae31d6e351eddbdf541d68f0061fb774b6c037a Mon Sep 17 00:00:00 2001
From: quzha
Date: Wed, 18 Mar 2020 10:17:30 +0800
Subject: [PATCH 4/9] update

---
 examples/model_compress/model_speedup.py                  | 6 +++---
 src/sdk/pynni/nni/compression/torch/apply_compression.py  | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/model_compress/model_speedup.py b/examples/model_compress/model_speedup.py
index 9dc7e997d5..c4ad3fb6bc 100644
--- a/examples/model_compress/model_speedup.py
+++ b/examples/model_compress/model_speedup.py
@@ -127,7 +127,7 @@ def slim_speedup(masks_file, model_checkpoint):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser("speedup")
-    parser.add_argument("--example_name", type=str, default="slim", help="the name of pruning example")
+    parser.add_argument("--example_name", type=str, default="apoz", help="the name of pruning example")
     parser.add_argument("--masks_file", type=str, default=None, help="the path of the masks file")
     parser.add_argument("--model_checkpoint", type=str, default=None, help="the path of checkpointed model")
     args = parser.parse_args()
@@ -142,11 +142,11 @@ def slim_speedup(masks_file, model_checkpoint):
         fpgm_speedup(args.masks_file, args.model_checkpoint)
     elif args.example_name == 'l1filter':
         if args.masks_file is None:
-            args.masks_file = 'mask_vgg16_cifar10.pth'
+            args.masks_file = './checkpoints/mask_vgg16_cifar10_l1.pth'
         l1filter_speedup(args.masks_file, args.model_checkpoint)
     elif args.example_name == 'apoz':
         if args.masks_file is None:
-            args.masks_file = 'mask_vgg16_cifar10.pth'
+            args.masks_file = './checkpoints/mask_vgg16_cifar10_apoz.pth'
         apoz_speedup(args.masks_file, args.model_checkpoint)
     else:
         raise ValueError('unsupported example_name: {}'.format(args.example_name))
diff --git a/src/sdk/pynni/nni/compression/torch/apply_compression.py b/src/sdk/pynni/nni/compression/torch/apply_compression.py
index 0f7949bd73..f9e0b7c3c1 100644
--- a/src/sdk/pynni/nni/compression/torch/apply_compression.py
+++ b/src/sdk/pynni/nni/compression/torch/apply_compression.py
@@ -21,5 +21,5 @@ def apply_compression_results(model, masks_file):
     for name, module in model.named_modules():
         if name in masks:
             module.weight.data = module.weight.data.mul_(masks[name]['weight'])
-            if 'bias' in masks[name]:
+            if hasattr(module, 'bias') and module.bias is not None and 'bias' in masks[name]:
                 module.bias.data = module.bias.data.mul_(masks[name]['bias'])
\ No newline at end of file

From b05d23dd61cc9242723614a5dcf0fb92109d2f86 Mon Sep 17 00:00:00 2001
From: quzha
Date: Wed, 18 Mar 2020 21:44:22 +0800
Subject: [PATCH 5/9] update

---
 examples/model_compress/model_prune_torch.py |   2 +-
 examples/model_compress/model_speedup.py     | 180 ++++++------
 .../compression/torch/apply_compression.py   |   6 +-
 3 files changed, 63 insertions(+), 125 deletions(-)

diff --git a/examples/model_compress/model_prune_torch.py b/examples/model_compress/model_prune_torch.py
index fb64ecf351..b03be0287c 100644
--- a/examples/model_compress/model_prune_torch.py
+++ b/examples/model_compress/model_prune_torch.py
@@ -127,7 +127,7 @@ def forward(self, x):
         x = F.max_pool2d(x, 2, 2)
         x = F.relu(self.bn2(self.conv2(x)))
         x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4 * 4 * 50)
+        x = x.view(x.size(0), -1)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return x
diff --git a/examples/model_compress/model_speedup.py b/examples/model_compress/model_speedup.py
index c4ad3fb6bc..efdb685ed9 100644
--- a/examples/model_compress/model_speedup.py
+++ b/examples/model_compress/model_speedup.py
@@ -10,143 +10,79 @@
 torch.manual_seed(0)
 use_mask = True
+use_speedup = True
+compare_results = True
 
-def apoz_speedup(masks_file, model_checkpoint):
-    device = torch.device('cuda')
-    model = VGG(depth=16)
-    model.to(device)
-    model.eval()
-
-    dummy_input = torch.randn(64, 3, 32, 32)
-    if use_mask:
-        apply_compression_results(model, masks_file)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(32):
-            out = model(dummy_input)
-            print(out.size(), out)
-        print('mask elapsed time: ', time.time() - start)
-        return
-    else:
-        #print("model before: ", model)
-        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
-        m_speedup.speedup_model()
-        #print("model after: ", model)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(32):
-            out = model(dummy_input)
-            print(out.size(), out)
-        print('speedup elapsed time: ', time.time() - start)
-        return
-
-def l1filter_speedup(masks_file, model_checkpoint):
-    device = torch.device('cuda')
-    model = VGG(depth=16)
-    model.to(device)
-    model.eval()
-
-    dummy_input = torch.randn(64, 3, 32, 32)
-    if use_mask:
-        apply_compression_results(model, masks_file)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(32):
-            out = model(dummy_input)
-            print(out.size(), out)
-        print('mask elapsed time: ', time.time() - start)
-        return
-    else:
-        #print("model before: ", model)
-        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
-        m_speedup.speedup_model()
-        print("model after: ", model)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(32):
-            out = model(dummy_input)
-            print(out.size(), out)
-        print('speedup elapsed time: ', time.time() - start)
-        return
-
-def fpgm_speedup(masks_file, model_checkpoint):
-    from model_prune_torch import NaiveModel
-    device = torch.device('cpu')
-    model = NaiveModel()
-    model.to(device)
-
-    dummy_input = torch.randn(64, 1, 28, 28)
-    if use_mask:
-        apply_compression_results(model, masks_file)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(40):
-            out = model(dummy_input)
-            print(out.size(), out)
-        print('mask elapsed time: ', time.time() - start)
-        return
-    else:
-        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
-        m_speedup.speedup_model()
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(40):
-            out = model(dummy_input)
-            print(out.size(), out)
-        print('speedup elapsed time: ', time.time() - start)
-        return
+config = {
+    'apoz': {
+        'model_name': 'vgg16',
+        'device': 'cuda',
+        'input_shape': [64, 3, 32, 32],
+        'masks_file': './checkpoints/mask_vgg16_cifar10_apoz.pth'
+    },
+    'l1filter': {
+        'model_name': 'vgg16',
+        'device': 'cuda',
+        'input_shape': [64, 3, 32, 32],
+        'masks_file': './checkpoints/mask_vgg16_cifar10_l1.pth'
+    },
+    'fpgm': {
+        'model_name': 'naive',
+        'device': 'cpu',
+        'input_shape': [64, 1, 28, 28],
+        'masks_file': './checkpoints/mask_naive_mnist_fpgm.pth'
+    },
+    'slim': {
+        'model_name': 'vgg19',
+        'device': 'cuda',
+        'input_shape': [64, 3, 32, 32],
+        'masks_file': 'mask_vgg19_cifar10.pth'
+    }
+}
 
-def slim_speedup(masks_file, model_checkpoint):
-    device = torch.device('cuda')
-    model = VGG(depth=19)
+def model_inference(config):
+    masks_file = config['masks_file']
+    device = torch.device(config['device'])
+    if config['model_name'] == 'vgg16':
+        model = VGG(depth=16)
+    elif config['model_name'] == 'vgg19':
+        model = VGG(depth=19)
+    elif config['model_name'] == 'naive':
+        from model_prune_torch import NaiveModel
+        model = NaiveModel()
     model.to(device)
     model.eval()
 
-    dummy_input = torch.randn(64, 3, 32, 32)
+    dummy_input = torch.randn(config['input_shape']).to(device)
+    use_mask_out = use_speedup_out = None
+    # must run use_mask before use_speedup because use_speedup modify the model
     if use_mask:
-        apply_compression_results(model, masks_file)
-        dummy_input = dummy_input.to(device)
+        apply_compression_results(model, masks_file, 'cpu' if config['device'] == 'cpu' else None)
         start = time.time()
         for _ in range(32):
-            out = model(dummy_input)
-            print(out.size(), out)
-        print('mask elapsed time: ', time.time() - start)
-        return
-    else:
-        #print("model before: ", model)
-        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
+            use_mask_out = model(dummy_input)
+        #print(out.size(), out)
+        print('elapsed time when use mask: ', time.time() - start)
+    if use_speedup:
+        m_speedup = ModelSpeedup(model, dummy_input, masks_file)
         m_speedup.speedup_model()
-        #print("model after: ", model)
-        dummy_input = dummy_input.to(device)
         start = time.time()
         for _ in range(32):
-            out = model(dummy_input)
-            print(out.size(), out)
-        print('speedup elapsed time: ', time.time() - start)
-        return
+            use_speedup_out = model(dummy_input)
+        #print(out.size(), out)
+        print('elapsed time when use speedup: ', time.time() - start)
+    if compare_results:
+        if torch.allclose(use_mask_out, use_speedup_out):
+            print('the outputs from use_mask and use_speedup are the same')
+        else:
+            print('ERROR: the outputs from use_mask and use_speedup are different')
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser("speedup")
     parser.add_argument("--example_name", type=str, default="apoz", help="the name of pruning example")
     parser.add_argument("--masks_file", type=str, default=None, help="the path of the masks file")
-    parser.add_argument("--model_checkpoint", type=str, default=None, help="the path of checkpointed model")
     args = parser.parse_args()
 
-    if args.example_name == 'slim':
-        if args.masks_file is None:
-            args.masks_file = 'mask_vgg19_cifar10.pth'
-        slim_speedup(args.masks_file, args.model_checkpoint)
-    elif args.example_name == 'fpgm':
-        if args.masks_file is None:
-            args.masks_file = './checkpoints/mask_naive_mnist_fpgm.pth'
-        fpgm_speedup(args.masks_file, args.model_checkpoint)
-    elif args.example_name == 'l1filter':
-        if args.masks_file is None:
-            args.masks_file = './checkpoints/mask_vgg16_cifar10_l1.pth'
-        l1filter_speedup(args.masks_file, args.model_checkpoint)
-    elif args.example_name == 'apoz':
-        if args.masks_file is None:
-            args.masks_file = './checkpoints/mask_vgg16_cifar10_apoz.pth'
-        apoz_speedup(args.masks_file, args.model_checkpoint)
-    else:
-        raise ValueError('unsupported example_name: {}'.format(args.example_name))
+    if args.masks_file is not None:
+        config[args.example_name]['masks_file'] = args.masks_file
+    model_inference(config[args.example_name])
diff --git a/src/sdk/pynni/nni/compression/torch/apply_compression.py b/src/sdk/pynni/nni/compression/torch/apply_compression.py
index f9e0b7c3c1..3831896108 100644
--- a/src/sdk/pynni/nni/compression/torch/apply_compression.py
+++ b/src/sdk/pynni/nni/compression/torch/apply_compression.py
@@ -6,9 +6,11 @@
 
 logger = logging.getLogger('torch apply compression')
 
-def apply_compression_results(model, masks_file):
+def apply_compression_results(model, masks_file, map_location=None):
     """
     Apply the masks from ```masks_file``` to the model
+    Note: this API is for inference, because it simply multiplies weights with
+    corresponding masks when this API is called.
 
     Parameters
     ----------
@@ -17,7 +19,7 @@ def apply_compression_results(model, masks_file):
     masks_file : str
         The path of the mask file
     """
-    masks = torch.load(masks_file)
+    masks = torch.load(masks_file, map_location)
     for name, module in model.named_modules():
         if name in masks:
             module.weight.data = module.weight.data.mul_(masks[name]['weight'])

From 6daf7ae48d07d5fad5ba9a6a11114cfa2fc5b541 Mon Sep 17 00:00:00 2001
From: quzha
Date: Wed, 18 Mar 2020 22:07:08 +0800
Subject: [PATCH 6/9] update

---
 examples/model_compress/model_speedup.py      | 23 +++++++++--------
 .../compression/speedup/torch/compressor.py   |  6 +++--
 .../compression/torch/apply_compression.py    |  2 ++
 3 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/examples/model_compress/model_speedup.py b/examples/model_compress/model_speedup.py
index efdb685ed9..2e7629ca2b 100644
--- a/examples/model_compress/model_speedup.py
+++ b/examples/model_compress/model_speedup.py
@@ -61,28 +61,33 @@ def model_inference(config):
         start = time.time()
         for _ in range(32):
             use_mask_out = model(dummy_input)
-        #print(out.size(), out)
         print('elapsed time when use mask: ', time.time() - start)
     if use_speedup:
-        m_speedup = ModelSpeedup(model, dummy_input, masks_file)
+        m_speedup = ModelSpeedup(model, dummy_input, masks_file,
+                                 'cpu' if config['device'] == 'cpu' else None)
         m_speedup.speedup_model()
         start = time.time()
         for _ in range(32):
             use_speedup_out = model(dummy_input)
-        #print(out.size(), out)
         print('elapsed time when use speedup: ', time.time() - start)
     if compare_results:
-        if torch.allclose(use_mask_out, use_speedup_out):
+        if torch.allclose(use_mask_out, use_speedup_out, atol=1e-07):
             print('the outputs from use_mask and use_speedup are the same')
         else:
-            print('ERROR: the outputs from use_mask and use_speedup are different')
+            raise RuntimeError('the outputs from use_mask and use_speedup are different')
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser("speedup")
-    parser.add_argument("--example_name", type=str, default="apoz", help="the name of pruning example")
+    parser.add_argument("--example_name", type=str, default="fpgm", help="the name of pruning example")
     parser.add_argument("--masks_file", type=str, default=None, help="the path of the masks file")
     args = parser.parse_args()
 
-    if args.masks_file is not None:
-        config[args.example_name]['masks_file'] = args.masks_file
-    model_inference(config[args.example_name])
+    if args.example_name != 'all':
+        if args.masks_file is not None:
+            config[args.example_name]['masks_file'] = args.masks_file
+        model_inference(config[args.example_name])
+    else:
+        model_inference(config['fpgm'])
+        model_inference(config['slim'])
+        model_inference(config['l1filter'])
+        model_inference(config['apoz'])
diff --git a/src/sdk/pynni/nni/compression/speedup/torch/compressor.py b/src/sdk/pynni/nni/compression/speedup/torch/compressor.py
index b6a9c94e40..c594aa2ff8 100644
--- a/src/sdk/pynni/nni/compression/speedup/torch/compressor.py
+++ b/src/sdk/pynni/nni/compression/speedup/torch/compressor.py
@@ -70,7 +70,7 @@ class ModelSpeedup:
     This class is to speedup the model with provided weight mask
     """
 
-    def __init__(self, model, dummy_input, masks_file):
+    def __init__(self, model, dummy_input, masks_file, map_location=None):
        """
         Parameters
         ----------
@@ -80,10 +80,12 @@ def __init__(self, model, dummy_input, masks_file):
             The dummy input for ```jit.trace```, users should put it on right device before pass in
         masks_file : str
             The path of user provided mask file
+        map_location : str
+            the device on which masks are placed, same to map_location in ```torch.load```
         """
         self.bound_model = model
         self.dummy_input = dummy_input
-        self.masks = torch.load(masks_file)
+        self.masks = torch.load(masks_file, map_location)
         self.is_training = model.training
         # to obtain forward graph, model should be in ```eval``` mode
         if self.is_training:
diff --git a/src/sdk/pynni/nni/compression/torch/apply_compression.py b/src/sdk/pynni/nni/compression/torch/apply_compression.py
index 3831896108..315a8579b7 100644
--- a/src/sdk/pynni/nni/compression/torch/apply_compression.py
+++ b/src/sdk/pynni/nni/compression/torch/apply_compression.py
@@ -18,6 +18,8 @@ def apply_compression_results(model, masks_file, map_location=None):
         The model to be compressed
     masks_file : str
         The path of the mask file
+    map_location : str
+        the device on which masks are placed, same to map_location in ```torch.load```
     """
     masks = torch.load(masks_file, map_location)
     for name, module in model.named_modules():

From 27b5d0d57b2dd7723f41bffbcec84659259df7e2 Mon Sep 17 00:00:00 2001
From: quzha
Date: Wed, 18 Mar 2020 22:18:18 +0800
Subject: [PATCH 7/9] update

---
 examples/model_compress/model_speedup.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/model_compress/model_speedup.py b/examples/model_compress/model_speedup.py
index 2e7629ca2b..307cea16e4 100644
--- a/examples/model_compress/model_speedup.py
+++ b/examples/model_compress/model_speedup.py
@@ -36,7 +36,7 @@
         'model_name': 'vgg19',
         'device': 'cuda',
         'input_shape': [64, 3, 32, 32],
-        'masks_file': 'mask_vgg19_cifar10.pth'
+        'masks_file': './checkpoints/mask_vgg19_cifar10_slim.pth' #'mask_vgg19_cifar10.pth'
     }
 }
 
@@ -78,10 +78,10 @@ def model_inference(config):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser("speedup")
-    parser.add_argument("--example_name", type=str, default="fpgm", help="the name of pruning example")
+    parser.add_argument("--example_name", type=str, default="slim", help="the name of pruning example")
     parser.add_argument("--masks_file", type=str, default=None, help="the path of the masks file")
     args = parser.parse_args()
-
+    
     if args.example_name != 'all':
         if args.masks_file is not None:
             config[args.example_name]['masks_file'] = args.masks_file

From d6996d2466982279deb568bf53b554b6bab7069e Mon Sep 17 00:00:00 2001
From: quzha
Date: Mon, 23 Mar 2020 09:12:20 +0800
Subject: [PATCH 8/9] update

---
 examples/model_compress/model_speedup.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/examples/model_compress/model_speedup.py b/examples/model_compress/model_speedup.py
index 307cea16e4..f427c63d8e 100644
--- a/examples/model_compress/model_speedup.py
+++ b/examples/model_compress/model_speedup.py
@@ -1,3 +1,4 @@
+import os
 import argparse
 import time
 import torch
@@ -85,6 +86,10 @@ def model_inference(config):
     if args.example_name != 'all':
         if args.masks_file is not None:
             config[args.example_name]['masks_file'] = args.masks_file
+        if not os.path.exists(config[args.example_name]['masks_file']):
+            msg = '{} does not exist! You should specify masks_file correctly, '
+            msg += 'or use default one which is generated by model_prune_torch.py'
+            raise RuntimeError(msg.format(config[args.example_name]['masks_file']))
         model_inference(config[args.example_name])
     else:
         model_inference(config['fpgm'])

From 48be5519a458447a1a31d495699270c7ea5a46a4 Mon Sep 17 00:00:00 2001
From: quzha
Date: Mon, 23 Mar 2020 10:23:11 +0800
Subject: [PATCH 9/9] update

---
 examples/model_compress/model_speedup.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/model_compress/model_speedup.py b/examples/model_compress/model_speedup.py
index f427c63d8e..f6ada91d96 100644
--- a/examples/model_compress/model_speedup.py
+++ b/examples/model_compress/model_speedup.py
@@ -87,8 +87,8 @@ def model_inference(config):
         if args.masks_file is not None:
             config[args.example_name]['masks_file'] = args.masks_file
         if not os.path.exists(config[args.example_name]['masks_file']):
-            msg = '{} does not exist! You should specify masks_file correctly, '
-            msg += 'or use default one which is generated by model_prune_torch.py'
+            msg = '{} does not exist! You should specify masks_file correctly, ' \
+                  'or use default one which is generated by model_prune_torch.py'
             raise RuntimeError(msg.format(config[args.example_name]['masks_file']))
         model_inference(config[args.example_name])
     else:
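The series nets out to two entry points: apply_compression_results, which only multiplies weights by their stored masks (inference-style "compression" with unchanged shapes), and ModelSpeedup, which infers coarse-grained shapes from the masks and replaces pruned layers with smaller ones. The following is a minimal usage sketch mirroring model_inference in examples/model_compress/model_speedup.py; the VGG and ModelSpeedup import paths and the checkpoint path are assumptions for illustration, not taken from the patches themselves.

```python
# Usage sketch only. Assumed (not shown in the patches): the VGG model import
# path, the ModelSpeedup import path, and the mask checkpoint location, which
# model_prune_torch.py is expected to have generated.
import torch
from models.cifar10.vgg import VGG                      # assumed example model path
from nni.compression.torch import apply_compression_results
from nni.compression.speedup.torch import ModelSpeedup  # assumed import path

model = VGG(depth=16)
model.eval()
dummy_input = torch.randn(64, 3, 32, 32)
masks_file = './checkpoints/mask_vgg16_cifar10_apoz.pth'

# 1) Mask-only "compression": multiplies weights (and biases) by the stored
#    masks in place; shapes are unchanged, so inference cost stays the same.
apply_compression_results(model, masks_file, map_location='cpu')
mask_out = model(dummy_input)

# 2) Real speedup: infers coarse-grained shapes from the masks and swaps the
#    pruned layers for smaller ones. Must run after masking, since it alters
#    the model (see the comment in model_inference above).
m_speedup = ModelSpeedup(model, dummy_input, masks_file, map_location='cpu')
m_speedup.speedup_model()
speedup_out = model(dummy_input)

# The two paths should agree up to numerical noise, as the example asserts.
assert torch.allclose(mask_out, speedup_out, atol=1e-07)
```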