Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix pruners for DataParallel support #2003

Merged
merged 3 commits into from
Feb 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions examples/model_compress/fpgm_torch_mnist.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import FPGMPruner

class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4 * 4 * 50, 500)
self.fc2 = nn.Linear(500, 10)

def forward(self, x):
x = F.relu(self.conv1(x))
Expand All @@ -27,8 +28,14 @@ def _get_conv_weight_sparsity(self, conv_layer):
return num_zero_filters, num_filters, float(num_zero_filters)/num_filters

def print_conv_filter_sparsity(self):
conv1_data = self._get_conv_weight_sparsity(self.conv1)
conv2_data = self._get_conv_weight_sparsity(self.conv2)
if isinstance(self.conv1, nn.Conv2d):
conv1_data = self._get_conv_weight_sparsity(self.conv1)
conv2_data = self._get_conv_weight_sparsity(self.conv2)
else:
# self.conv1 is wrapped as PrunerModuleWrapper
conv1_data = self._get_conv_weight_sparsity(self.conv1.module)
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
conv2_data = self._get_conv_weight_sparsity(self.conv2.module)

print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2]))
print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2]))

Expand Down
3 changes: 1 addition & 2 deletions src/sdk/pynni/nni/compression/torch/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def forward(self, *inputs):
self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
# apply mask to bias
if hasattr(self.module, 'bias') and self.module.bias is not None:
if mask is not None:
if mask is not None and 'bias' in mask:
self.bias_mask.copy_(mask['bias'])
self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
return self.module(*inputs)
Expand Down Expand Up @@ -565,4 +565,3 @@ def _check_weight(module):
return isinstance(module.weight.data, torch.Tensor)
except AttributeError:
return False

48 changes: 26 additions & 22 deletions src/sdk/pynni/nni/compression/torch/pruners.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,42 +84,47 @@ def __init__(self, model, config_list):

super().__init__(model, config_list)
self.now_epoch = 0
self.if_init_list = {}
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

def calc_mask(self, layer, config):
def calc_mask(self, layer, config, **kwargs):
"""
Calculate the mask of given layer
Calculate the mask of given layer.
Weights with the smallest absolute values are masked, up to the target sparsity.
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
kwargs: dict
buffers registered in __init__ function
Returns
-------
dict
dictionary for storing masks
"""

weight = layer.module.weight.data
op_name = layer.name
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
and (self.now_epoch - start_epoch) % freq == 0:
mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate a new mask, we should update the weight first
w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
self.mask_dict.update({op_name: new_mask})
self.if_init_list.update({op_name: False})
else:
new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})

if_calculated = kwargs["if_calculated"]
if if_calculated:
return None
if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
return None

mask = {'weight': torch.ones(weight.shape).type_as(weight)}
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate a new mask, we should update the weight first
w_abs = weight.abs()
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable

return new_mask

def compute_target_sparsity(self, config):
Expand Down Expand Up @@ -165,9 +170,8 @@ def update_epoch(self, epoch):

if epoch > 0:
self.now_epoch = epoch
for k in self.if_init_list.keys():
self.if_init_list[k] = True

for wrapper in self.get_modules_wrapper():
wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which one is better: wrapper.registered_buffers['if_calculated'] or wrapper.if_calculated?

Copy link
Contributor Author

@chicm-ms chicm-ms Feb 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just tested: wrapper.if_calculated does not work. No error is reported, but the value is still 1 after calling wrapper.if_calculated.copy_(torch.tensor(0)).


class SlimPruner(Pruner):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, model, config_list):
"""

super().__init__(model, config_list)
self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

# Placeholder implementation: always raises NotImplementedError, with a message
# that names the concrete class, so a subclass has to supply its own get_mask.
# calc_mask (next hunk) invokes it as self.get_mask(mask, weight, num_prune)
# and returns the result as the layer's mask.
def get_mask(self, base_mask, weight, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
Expand Down Expand Up @@ -69,7 +69,7 @@ def calc_mask(self, layer, config, **kwargs):
return mask
mask = self.get_mask(mask, weight, num_prune)
finally:
if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
return mask


Expand Down Expand Up @@ -257,4 +257,5 @@ def _get_distance_sum(self, weight, in_idx, out_idx):
return x.sum()

def update_epoch(self, epoch):
self.mask_calculated_ops = set()
for wrapper in self.get_modules_wrapper():
wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable
2 changes: 0 additions & 2 deletions src/sdk/pynni/tests/test_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def test_torch_fpgm_pruner(self):
masks = pruner.calc_mask(layer, config_list[0], if_calculated=torch.tensor(0))
assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))

pruner.update_epoch(1)
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
model.conv2.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(layer, config_list[1], if_calculated=torch.tensor(0))
assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
Expand All @@ -159,7 +158,6 @@ def test_tf_fpgm_pruner(self):

assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))

pruner.update_epoch(1)
model.layers[2].set_weights([weights[0], weights[1].numpy()])
masks = pruner.calc_mask(layer, config_list[1]).numpy()
masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
Expand Down