diff --git a/examples/model_compress/model_prune_torch.py b/examples/model_compress/model_prune_torch.py
index fb64ecf351..b03be0287c 100644
--- a/examples/model_compress/model_prune_torch.py
+++ b/examples/model_compress/model_prune_torch.py
@@ -127,7 +127,7 @@ def forward(self, x):
         x = F.max_pool2d(x, 2, 2)
         x = F.relu(self.bn2(self.conv2(x)))
         x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4 * 4 * 50)
+        x = x.view(x.size(0), -1)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return x
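Why this change is needed: once speedup physically shrinks conv2's output channels, the flattened feature count is no longer 4 * 4 * 50, and the hard-coded view silently folds the surplus into the batch dimension instead of raising an error. A minimal sketch of the failure mode (illustrative only; the 25-channel figure is a hypothetical pruning result, not taken from this patch):

    import torch

    feat = torch.randn(64, 25, 4, 4)          # hypothetical feature map after pruning 50 -> 25 channels
    print(feat.view(-1, 4 * 4 * 50).shape)    # torch.Size([32, 800]): wrong batch size, no error raised
    print(feat.view(feat.size(0), -1).shape)  # torch.Size([64, 400]): batch dimension preserved
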
diff --git a/examples/model_compress/model_speedup.py b/examples/model_compress/model_speedup.py
index 9d27d98da9..f6ada91d96 100644
--- a/examples/model_compress/model_speedup.py
+++ b/examples/model_compress/model_speedup.py
@@ -1,3 +1,4 @@
+import os
 import argparse
 import time
 import torch
@@ -9,145 +10,89 @@
 from nni.compression.torch import apply_compression_results
 
 torch.manual_seed(0)
-use_mask = False
+use_mask = True
+use_speedup = True
+compare_results = True
 
-def apoz_speedup(masks_file, model_checkpoint):
-    device = torch.device('cuda')
-    model = VGG(depth=16)
-    model.to(device)
-    model.eval()
-
-    dummy_input = torch.randn(64, 3, 32, 32)
-    if use_mask:
-        apply_compression_results(model, masks_file)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(32):
-            out = model(dummy_input)
-        #print(out.size(), out)
-        print('mask elapsed time: ', time.time() - start)
-        return
-    else:
-        #print("model before: ", model)
-        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
-        m_speedup.speedup_model()
-        #print("model after: ", model)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(32):
-            out = model(dummy_input)
-        #print(out.size(), out)
-        print('speedup elapsed time: ', time.time() - start)
-        return
+config = {
+    'apoz': {
+        'model_name': 'vgg16',
+        'device': 'cuda',
+        'input_shape': [64, 3, 32, 32],
+        'masks_file': './checkpoints/mask_vgg16_cifar10_apoz.pth'
+    },
+    'l1filter': {
+        'model_name': 'vgg16',
+        'device': 'cuda',
+        'input_shape': [64, 3, 32, 32],
+        'masks_file': './checkpoints/mask_vgg16_cifar10_l1.pth'
+    },
+    'fpgm': {
+        'model_name': 'naive',
+        'device': 'cpu',
+        'input_shape': [64, 1, 28, 28],
+        'masks_file': './checkpoints/mask_naive_mnist_fpgm.pth'
+    },
+    'slim': {
+        'model_name': 'vgg19',
+        'device': 'cuda',
+        'input_shape': [64, 3, 32, 32],
+        'masks_file': './checkpoints/mask_vgg19_cifar10_slim.pth'
+    }
+}
 
-def l1filter_speedup(masks_file, model_checkpoint):
-    device = torch.device('cuda')
-    model = VGG(depth=16)
+def model_inference(config):
+    masks_file = config['masks_file']
+    device = torch.device(config['device'])
+    if config['model_name'] == 'vgg16':
+        model = VGG(depth=16)
+    elif config['model_name'] == 'vgg19':
+        model = VGG(depth=19)
+    elif config['model_name'] == 'naive':
+        from model_prune_torch import NaiveModel
+        model = NaiveModel()
     model.to(device)
     model.eval()
 
-    dummy_input = torch.randn(64, 3, 32, 32)
+    dummy_input = torch.randn(config['input_shape']).to(device)
+    use_mask_out = use_speedup_out = None
+    # must run use_mask before use_speedup because use_speedup modifies the model
    if use_mask:
-        apply_compression_results(model, masks_file)
-        dummy_input = dummy_input.to(device)
+        apply_compression_results(model, masks_file, 'cpu' if config['device'] == 'cpu' else None)
         start = time.time()
         for _ in range(32):
-            out = model(dummy_input)
-        #print(out.size(), out)
-        print('mask elapsed time: ', time.time() - start)
-        return
-    else:
-        #print("model before: ", model)
-        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
+            use_mask_out = model(dummy_input)
+        print('elapsed time when use mask: ', time.time() - start)
+    if use_speedup:
+        m_speedup = ModelSpeedup(model, dummy_input, masks_file,
+                                 'cpu' if config['device'] == 'cpu' else None)
         m_speedup.speedup_model()
-        #print("model after: ", model)
-        dummy_input = dummy_input.to(device)
         start = time.time()
         for _ in range(32):
-            out = model(dummy_input)
-        #print(out.size(), out)
-        print('speedup elapsed time: ', time.time() - start)
-        return
-
-def fpgm_speedup(masks_file, model_checkpoint):
-    from fpgm_torch_mnist import Mnist
-    device = torch.device('cpu')
-    model = Mnist()
-    model.to(device)
-    model.print_conv_filter_sparsity()
-
-    dummy_input = torch.randn(64, 1, 28, 28)
-    if use_mask:
-        apply_compression_results(model, masks_file)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(40):
-            out = model(dummy_input)
-        print('mask elapsed time: ', time.time() - start)
-        #print(out.size(), out)
-        return
-    else:
-        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
-        m_speedup.speedup_model()
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(40):
-            out = model(dummy_input)
-        print('speedup elapsed time: ', time.time() - start)
-        #print(out.size(), out)
-        return
-
-def slim_speedup(masks_file, model_checkpoint):
-    device = torch.device('cuda')
-    model = VGG(depth=19)
-    model.to(device)
-    model.eval()
-
-    dummy_input = torch.randn(64, 3, 32, 32)
-    if use_mask:
-        apply_compression_results(model, masks_file)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(32):
-            out = model(dummy_input)
-        #print(out.size(), out)
-        print('mask elapsed time: ', time.time() - start)
-        return
-    else:
-        #print("model before: ", model)
-        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
-        m_speedup.speedup_model()
-        #print("model after: ", model)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(32):
-            out = model(dummy_input)
-        #print(out.size(), out)
-        print('speedup elapsed time: ', time.time() - start)
-        return
+        use_speedup_out = model(dummy_input)
+        print('elapsed time when use speedup: ', time.time() - start)
+    if compare_results:
+        if torch.allclose(use_mask_out, use_speedup_out, atol=1e-07):
+            print('the outputs from use_mask and use_speedup are the same')
+        else:
+            raise RuntimeError('the outputs from use_mask and use_speedup are different')
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser("speedup")
     parser.add_argument("--example_name", type=str, default="slim", help="the name of pruning example")
     parser.add_argument("--masks_file", type=str, default=None, help="the path of the masks file")
-    parser.add_argument("--model_checkpoint", type=str, default=None, help="the path of checkpointed model")
     args = parser.parse_args()
-
-    if args.example_name == 'slim':
-        if args.masks_file is None:
-            args.masks_file = 'mask_vgg19_cifar10.pth'
-        slim_speedup(args.masks_file, args.model_checkpoint)
-    elif args.example_name == 'fpgm':
-        if args.masks_file is None:
-            args.masks_file = 'mask.pth'
-        fpgm_speedup(args.masks_file, args.model_checkpoint)
-    elif args.example_name == 'l1filter':
-        if args.masks_file is None:
-            args.masks_file = 'mask_vgg16_cifar10.pth'
-        l1filter_speedup(args.masks_file, args.model_checkpoint)
-    elif args.example_name == 'apoz':
-        if args.masks_file is None:
-            args.masks_file = 'mask_vgg16_cifar10.pth'
-        apoz_speedup(args.masks_file, args.model_checkpoint)
+
+    if args.example_name != 'all':
+        if args.masks_file is not None:
+            config[args.example_name]['masks_file'] = args.masks_file
+        if not os.path.exists(config[args.example_name]['masks_file']):
+            msg = '{} does not exist! You should specify masks_file correctly, ' \
+                  'or use the default one generated by model_prune_torch.py'
+            raise RuntimeError(msg.format(config[args.example_name]['masks_file']))
+        model_inference(config[args.example_name])
     else:
-        raise ValueError('unsupported example_name: {}'.format(args.example_name))
+        model_inference(config['fpgm'])
+        model_inference(config['slim'])
+        model_inference(config['l1filter'])
+        model_inference(config['apoz'])
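With the four per-pruner functions folded into one config table, each example is run as `python model_speedup.py --example_name fpgm` (or `--example_name all`), and adding a new case is just another dict entry. A hypothetical extension sketch (the mask path below is made up for illustration; the model name must be one that model_inference dispatches on):

    custom = {
        'model_name': 'naive',
        'device': 'cpu',
        'input_shape': [64, 1, 28, 28],
        'masks_file': './checkpoints/my_masks.pth'  # hypothetical path
    }
    model_inference(custom)   # runs masked inference, speedup, and the output comparison
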
diff --git a/src/sdk/pynni/nni/compression/speedup/torch/compressor.py b/src/sdk/pynni/nni/compression/speedup/torch/compressor.py
index b6a9c94e40..c594aa2ff8 100644
--- a/src/sdk/pynni/nni/compression/speedup/torch/compressor.py
+++ b/src/sdk/pynni/nni/compression/speedup/torch/compressor.py
@@ -70,7 +70,7 @@ class ModelSpeedup:
     This class is to speedup the model with provided weight mask
     """
 
-    def __init__(self, model, dummy_input, masks_file):
+    def __init__(self, model, dummy_input, masks_file, map_location=None):
         """
         Parameters
         ----------
@@ -80,10 +80,12 @@ def __init__(self, model, dummy_input, masks_file):
             The dummy input for ```jit.trace```, users should put it on right device before pass in
         masks_file : str
             The path of user provided mask file
+        map_location : str
+            the device on which masks are placed, same as map_location in ```torch.load```
         """
         self.bound_model = model
         self.dummy_input = dummy_input
-        self.masks = torch.load(masks_file)
+        self.masks = torch.load(masks_file, map_location)
         self.is_training = model.training
         # to obtain forward graph, model should be in ```eval``` mode
         if self.is_training:
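The new map_location parameter lets masks saved during a CUDA run be applied on a CPU-only machine. A minimal sketch (the import paths and mask file are assumed from the examples above, not asserted by this patch):

    import torch
    from nni.compression.speedup.torch import ModelSpeedup
    from model_prune_torch import NaiveModel

    model = NaiveModel()
    dummy_input = torch.randn(64, 1, 28, 28)
    # GPU-saved mask tensors are remapped to CPU before the graph is rewritten
    m_speedup = ModelSpeedup(model, dummy_input, './checkpoints/mask_naive_mnist_fpgm.pth',
                             map_location='cpu')
    m_speedup.speedup_model()
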
diff --git a/src/sdk/pynni/nni/compression/torch/apply_compression.py b/src/sdk/pynni/nni/compression/torch/apply_compression.py
index 2531da5039..315a8579b7 100644
--- a/src/sdk/pynni/nni/compression/torch/apply_compression.py
+++ b/src/sdk/pynni/nni/compression/torch/apply_compression.py
@@ -3,13 +3,14 @@
 
 import logging
 import torch
-from .compressor import Pruner
 
 logger = logging.getLogger('torch apply compression')
 
-def apply_compression_results(model, masks_file):
+def apply_compression_results(model, masks_file, map_location=None):
     """
     Apply the masks from ```masks_file``` to the model
+    Note: this API is for inference only, because it simply multiplies weights by the
+    corresponding masks when it is called.
 
     Parameters
     ----------
@@ -17,54 +18,12 @@
         The model to be compressed
     masks_file : str
         The path of the mask file
+    map_location : str
+        the device on which masks are placed, same as map_location in ```torch.load```
     """
-    apply_comp = ApplyCompression(model, masks_file)
-    apply_comp.compress()
-
-class ApplyCompression(Pruner):
-    """
-    This class is not to generate masks, but applying existing masks
-    """
-
-    def __init__(self, model, masks_file):
-        """
-        Parameters
-        ----------
-        model : torch.nn.module
-            Model to be masked
-        masks_file : str
-            The path of user provided mask file
-        """
-        self.bound_model = model
-        self.masks = torch.load(masks_file)
-        for module_name in self.masks:
-            print('module_name: ', module_name)
-        config_list = self._build_config()
-        super().__init__(model, config_list)
-
-    def _build_config(self):
-        op_names = []
-        for module_name in self.masks:
-            op_names.append(module_name)
-        return [{'sparsity': 1, 'op_types': ['default', 'BatchNorm2d'], 'op_names': op_names}]
-
-    def calc_mask(self, layer, config, **kwargs):
-        """
-        Directly return the corresponding mask
-
-        Parameters
-        ----------
-        layer : LayerInfo
-            The layer to be pruned
-        config : dict
-            Pruning configurations for this weight
-        kwargs : dict
-            Auxiliary information
-
-        Returns
-        -------
-        dict
-            Mask of the layer
-        """
-        assert layer.name in self.masks
-        return self.masks[layer.name]
+    masks = torch.load(masks_file, map_location)
+    for name, module in model.named_modules():
+        if name in masks:
+            module.weight.data = module.weight.data.mul_(masks[name]['weight'])
+            if hasattr(module, 'bias') and module.bias is not None and 'bias' in masks[name]:
+                module.bias.data = module.bias.data.mul_(masks[name]['bias'])
\ No newline at end of file
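Since apply_compression_results now bakes the masks directly into the weight tensors, masked inference no longer needs the Pruner machinery. A minimal usage sketch (model and mask path borrowed from the examples above; the mask file is assumed to have been produced by model_prune_torch.py):

    import torch
    from nni.compression.torch import apply_compression_results
    from model_prune_torch import NaiveModel

    model = NaiveModel()
    model.eval()
    # weights are multiplied by their masks once, in place: a one-shot,
    # inference-only transformation, not suitable for resuming training
    apply_compression_results(model, './checkpoints/mask_naive_mnist_fpgm.pth', map_location='cpu')
    out = model(torch.randn(64, 1, 28, 28))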