diff --git a/examples/model_compress/main_torch_pruner.py b/examples/model_compress/main_torch_pruner.py index 10f7355cd1..17a6b131fe 100644 --- a/examples/model_compress/main_torch_pruner.py +++ b/examples/model_compress/main_torch_pruner.py @@ -55,7 +55,7 @@ def test(model, device, test_loader): def main(): torch.manual_seed(0) - device = torch.device('cpu') + device = torch.device('cuda') trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) train_loader = torch.utils.data.DataLoader( @@ -66,7 +66,7 @@ def main(): batch_size=1000, shuffle=True) model = Mnist() - model.to(device) + model = model.to(device) '''you can change this to LevelPruner to implement it pruner = LevelPruner(configure_list) @@ -82,14 +82,14 @@ def main(): pruner = AGP_Pruner(model, configure_list) model = pruner.compress() - + model = model.to(device) optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) for epoch in range(10): pruner.update_epoch(epoch) print('# Epoch {} #'.format(epoch)) train(model, device, train_loader, optimizer) test(model, device, test_loader) - pruner.export_model('model.pth', 'mask.pth', 'model.onnx', [1, 1, 28, 28]) + pruner.export_model('model.pth', 'mask.pth', 'model.onnx', [1, 1, 28, 28], device) if __name__ == '__main__': diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py index 14cbc194f7..107db55a0b 100644 --- a/src/sdk/pynni/nni/compression/torch/compressor.py +++ b/src/sdk/pynni/nni/compression/torch/compressor.py @@ -226,7 +226,7 @@ def __init__(self, module, module_name, module_type, config, pruner): # config and pruner self.config = config self.pruner = pruner - self.registered_buffers = {} + self.registered_buffers = [] # register buffer for mask self.register_buffer("weight_mask", torch.ones(self.module.weight.shape)) @@ -234,16 +234,21 @@ def __init__(self, module, module_name, module_type, config, pruner): self.register_buffer("bias_mask", torch.ones(self.module.bias.shape)) else: self.register_buffer("bias_mask", None) - - self.registered_buffers['weight_mask'] = self.weight_mask - self.registered_buffers['bias_mask'] = self.bias_mask + self.registered_buffers.append('weight_mask') + self.registered_buffers.append('bias_mask') # register user specified buffer for name in self.pruner.buffers: self.register_buffer(name, self.pruner.buffers[name].clone()) - self.registered_buffers[name] = getattr(self, name) + self.registered_buffers.append(name) + + def get_registered_buffers(self): + buffers = {} + for name in self.registered_buffers: + buffers[name] = getattr(self, name) + return buffers def forward(self, *inputs): - mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.registered_buffers) + mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.get_registered_buffers()) if mask is not None: self.weight_mask.copy_(mask['weight']) # apply mask to weight @@ -399,6 +404,7 @@ def __init__(self, module, module_name, module_type, config, quantizer): # config and pruner self.config = config self.quantizer = quantizer + self.registered_buffers = [] # register buffer and parameter # old_weight is used to store origin weight and weight is used to store quantized weight @@ -413,10 +419,15 @@ def __init__(self, module, module_name, module_type, config, quantizer): self.module.register_buffer('weight', self.module.old_weight) # register user specified buffer - self.registered_buffers = {} for name in self.quantizer.buffers: self.register_buffer(name, self.quantizer.buffers[name].clone()) - self.registered_buffers[name] = getattr(self, name) + self.registered_buffers.append(name) + + def get_registered_buffers(self): + buffers = {} + for name in self.registered_buffers: + buffers[name] = getattr(self, name) + return buffers def forward(self, *inputs): if 'input' in self.config['quant_types']: @@ -426,7 +437,7 @@ def forward(self, *inputs): self.quantizer.quantize_input, self.config, LayerInfo(self.name, self.module), - **self.registered_buffers) + **self.get_registered_buffers()) if 'weight' in self.config['quant_types'] and _check_weight(self.module): new_weight = self.quantizer.quant_grad.apply( @@ -435,7 +446,7 @@ def forward(self, *inputs): self.quantizer.quantize_weight, self.config, LayerInfo(self.name, self.module), - **self.registered_buffers) + **self.get_registered_buffers()) self.module.weight = new_weight result = self.module(*inputs) else: @@ -448,7 +459,7 @@ def forward(self, *inputs): self.quantizer.quantize_output, self.config, LayerInfo(self.name, self.module), - **self.registered_buffers) + **self.get_registered_buffers()) return result class Quantizer(Compressor): diff --git a/src/sdk/pynni/nni/compression/torch/pruners.py b/src/sdk/pynni/nni/compression/torch/pruners.py index b49ed77f2f..f186fb6917 100644 --- a/src/sdk/pynni/nni/compression/torch/pruners.py +++ b/src/sdk/pynni/nni/compression/torch/pruners.py @@ -170,7 +170,7 @@ def update_epoch(self, epoch): if epoch > 0: self.now_epoch = epoch for wrapper in self.get_modules_wrapper(): - wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable + wrapper.if_calculated.copy_(torch.tensor(0)) # pylint: disable=not-callable class SlimPruner(Pruner): """