This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Model Compression] Add global sort for taylor pruner #3896

Merged
merged 11 commits into from Jul 20, 2021
2 changes: 2 additions & 0 deletions docs/en_US/Compression/Pruner.rst
@@ -334,6 +334,8 @@ TaylorFOWeightFilter Pruner is a pruner which prunes convolutional layers based

We also provide a dependency-aware mode for this pruner to get a better speedup from pruning. Please refer to `dependency-aware <./DependencyAware.rst>`__ for more details.

In addition, we provide a global-sort mode for this pruner, which is aligned with the implementation in the original paper. Please set the parameter ``global_sort`` to ``True`` when instantiating TaylorFOWeightFilterPruner.
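For example, a minimal sketch of enabling it (the ``model``, ``trainer``, and ``criterion`` here are placeholders; the constructor signature follows the change in this PR):

.. code-block:: python

    from nni.algorithms.compression.pytorch.pruning import TaylorFOWeightFilterPruner

    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
    pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=trainer,
                                        criterion=criterion, sparsifying_training_batches=1,
                                        global_sort=True)
    model = pruner.compress()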

Usage
^^^^^

8 changes: 7 additions & 1 deletion examples/model_compress/pruning/basic_pruners_torch.py
@@ -218,6 +218,10 @@ def trainer(model, optimizer, criterion, epoch):
}]

else:
if args.global_sort:
print('Enable the global_sort mode')
# only the Taylor pruner supports the global-sort mode currently
kw_args['global_sort'] = True
if args.dependency_aware:
dummy_input = get_dummy_input(args, device)
print('Enable the dependency_aware mode')
@@ -331,6 +335,8 @@ def trainer(model, optimizer, criterion, epoch):
help='target overall sparsity')
parser.add_argument('--dependency-aware', action='store_true', default=False,
help='toggle dependency aware mode')
parser.add_argument('--global-sort', action='store_true', default=False,
help='toggle global sort mode')
parser.add_argument('--pruner', type=str, default='l1filter',
choices=['level', 'l1filter', 'l2filter', 'slim', 'agp',
'fpgm', 'mean_activation', 'apoz', 'taylorfo'],
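# For instance, the new flag can then be toggled from the command line (a sketch;
# other flags of this example script keep their defaults):
#   python examples/model_compress/pruning/basic_pruners_torch.py --pruner taylorfo --global-sort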
@@ -356,4 +362,4 @@ def trainer(model, optimizer, criterion, epoch):
args.pruner = params['pruner']
args.model = params['model']

main(args)
@@ -486,14 +486,20 @@ class TaylorFOWeightFilterPruner(IterativePruner):
dummy_input : torch.Tensor
The dummy input to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
global_sort : bool
Only supported by TaylorFOWeightFilterPruner currently.
Whether to prune the model in a global-sort way. If `True`, this pruner prunes
the model according to global contribution information: channel contributions
are sorted globally, and whether a specific channel is pruned depends on that global ranking.
"""

def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1,
dependency_aware=False, dummy_input=None):
dependency_aware=False, dummy_input=None, global_sort=False):
super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo', trainer=trainer,
criterion=criterion, statistics_batch_num=sparsifying_training_batches, num_iterations=1,
epochs_per_iteration=1, dependency_aware=dependency_aware,
dummy_input=dummy_input)
self.masker.global_sort = global_sort

def _supported_dependency_aware(self):
return True
@@ -33,11 +33,12 @@ class StructuredWeightMasker(WeightMasker):

"""

def __init__(self, model, pruner, preserve_round=1, dependency_aware=False):
def __init__(self, model, pruner, preserve_round=1, dependency_aware=False, global_sort=False):
Contributor

I think it may be better to implement this for Taylor alone without changing the underlying interface, because no other masker uses these interfaces, and slim naturally uses global sort.

Contributor Author

But based on our discussion before, we agreed that we should modify the condition in

https://github.com/microsoft/nni/blob/master/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py#L63

Maybe we can add an assert to tell users that we only support Taylor currently?

Contributor

The current implementation is good for me; it was just a small opinion, no need to modify. And we need to update the doc for Taylor.

Contributor Author

Have updated the doc.

Contributor

So you put global sort in StructuredWeightMasker but not in TaylorFOWeightFilterPrunerMasker. What is the reason for that? Is it because this would give TaylorFOWeightFilterPrunerMasker different init arguments from the other structured weight maskers?

Contributor Author

Because here we modify the conditional statement that distinguishes the different situations, including global-sort and dependency-aware. We do this check here because we think global-sort should be at the same level as dependency-aware. Since this check is done in StructuredWeightMasker, we put global_sort there.

self.model = model
self.pruner = pruner
self.preserve_round = preserve_round
self.dependency_aware = dependency_aware
self.global_sort = global_sort

def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs):
"""
@@ -60,7 +61,11 @@ def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs):
depen_kwargs: dict
The kw_args for the dependency-aware mode.
"""
if not self.dependency_aware:
if self.global_sort:
# if the global_sort switch is on, calculate the mask based
# on global model information
return self._global_calc_mask(sparsity, wrapper, wrapper_idx)
elif not self.dependency_aware:
# calculate the mask in the normal way, each layer calculate its
# own mask separately
return self._normal_calc_mask(sparsity, wrapper, wrapper_idx)
@@ -127,6 +132,12 @@ def _get_current_state(self, sparsity, wrapper, wrapper_idx=None):
# weight*mask_weight: apply base mask for iterative pruning
return mask, weight * mask_weight, num_prune

def _global_calc_mask(self, sparsity, wrapper, wrapper_idx=None):
num_prune = self._get_global_num_prune(wrapper, wrapper_idx)
mask, weight, _ = self._get_current_state(
sparsity, wrapper, wrapper_idx)
return self.get_mask(mask, weight, num_prune, wrapper, wrapper_idx)

def _normal_calc_mask(self, sparsity, wrapper, wrapper_idx=None):
"""
Calculate the mask of given layer.
@@ -477,6 +488,31 @@ def __init__(self, model, pruner, statistics_batch_num=1):
self.pruner.iterations = 0
self.pruner.set_wrappers_attribute("contribution", None)
self.pruner.patch_optimizer(self.calc_contributions)
self.global_threshold = None

def _get_global_threshold(self):
channel_contribution_list = []
for wrapper_idx, wrapper in enumerate(self.pruner.get_modules_wrapper()):
channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
wrapper_size = wrapper.module.weight.size().numel()
channel_size = wrapper.module.weight.size(0)
# repeat each channel's contribution once per weight element it owns, so layers
# with different filter sizes are weighted by parameter count in the global ranking
contribution_expand = channel_contribution.expand(int(wrapper_size / channel_size), channel_size).reshape(-1)
channel_contribution_list.append(contribution_expand)
all_channel_contributions = torch.cat(channel_contribution_list)
k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity'])
Contributor

What if the filters' sizes are different?

Contributor Author

That's truly a key problem; I will fix it.

Contributor Author

Have fixed.

Contributor

Seems we need view(-1) on the contribution; or can they be concatenated together with different sizes?

Contributor Author

That's a problem. Have fixed.

# the threshold is the largest of the k smallest expanded contributions
self.global_threshold = torch.topk(
all_channel_contributions.view(-1), k, largest=False)[0].max()
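# As a standalone illustration of the expansion step above (a sketch with hypothetical
# shapes matching the unit test below; not code from this PR):
#
#   import torch
#   contrib_a = torch.tensor([0., 1., 2., 3., 4.])   # layer A: 5 filters of 25 weights each
#   contrib_b = torch.arange(10, 0, -1).float()      # layer B: 10 filters of 125 weights each
#   expand_a = contrib_a.expand(25, 5).reshape(-1)   # 125 entries, one per weight element
#   expand_b = contrib_b.expand(125, 10).reshape(-1) # 1250 entries
#   all_contrib = torch.cat([expand_a, expand_b])    # cat works once both are flattened
#   k = int(all_contrib.shape[0] * 0.4)              # overall sparsity 0.4
#   threshold = torch.topk(all_contrib, k, largest=False)[0].max()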

def _get_global_num_prune(self, wrapper, wrapper_idx):
if self.global_threshold is None:
self._get_global_threshold()
weight = wrapper.module.weight.data
filters = weight.size(0)
channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
num_prune = channel_contribution[channel_contribution < self.global_threshold].size()[0]
Contributor

Just a dumb question, do we want < or <= here?

Contributor Author

It is a key point and I thought about it before. We can use '<=', which may make num_prune larger than the specified sparsity, or '<', which may make num_prune smaller than the specified sparsity. I finally chose '<' since a slightly lower sparsity is safer, and the iterative pruning process will do further pruning.
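A tiny illustration of that trade-off (hypothetical values, not from this PR):

import torch
contrib = torch.tensor([0.1, 0.2, 0.2, 0.2, 0.9])
k = 3  # target: prune 3 of 5 channels
threshold = torch.topk(contrib, k, largest=False)[0].max()  # 0.2
pruned_lt = (contrib < threshold).sum().item()   # 1 -> '<' under-prunes on threshold ties
pruned_le = (contrib <= threshold).sum().item()  # 4 -> '<=' over-prunes on the same ties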

# never prune every filter in a layer; keep at least one
if num_prune == filters:
num_prune -= 1
return num_prune

def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
44 changes: 44 additions & 0 deletions test/ut/sdk/test_compressor_torch.py
@@ -219,6 +219,50 @@ def test_torch_taylorFOweight_pruner(self):
assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))

def test_torch_taylorFOweight_pruner_global_sort(self):
"""
After enabling global_sort, taylorFOweight pruner will calculate contributions and rank topk from all
of the conv operators. Then it will prune low contribution filters depends on the global information.

So if sparsity of conv operator is 0.4, the expected masks should mask out filter 0 and filter 1 together,
this can be verified through:
`all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))`
`all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))`
"""

w1 = np.array([np.zeros((1, 5, 5)), np.ones((1, 5, 5)), np.ones((1, 5, 5)) * 2,
np.ones((1, 5, 5)) * 3, np.ones((1, 5, 5)) * 4])
w2 = np.array([[[[i + 1] * 5] * 5] * 5 for i in range(10)[::-1]])

grad1 = np.array([np.ones((1, 5, 5)) * -1, np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1,
np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1])

grad2 = np.array([[[[(-1)**i] * 5] * 5] * 5 for i in range(10)])

config_list = [{'sparsity': 0.4, 'op_types': ['Conv2d']}]

model = TorchModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsifying_training_batches=1, global_sort=True)

x = torch.rand((1, 1, 28, 28), requires_grad=True)
model.conv1.module.weight.data = torch.tensor(w1).float()
model.conv2.module.weight.data = torch.tensor(w2).float()

y = model(x)
y.backward(torch.ones_like(y))

model.conv1.module.weight.grad.data = torch.tensor(grad1).float()
model.conv2.module.weight.grad.data = torch.tensor(grad2).float()
optimizer.step()

mask1 = pruner.calc_mask(model.conv1)
mask2 = pruner.calc_mask(model.conv2)
print(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy())
print(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy())
assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 25.]))
assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))

def test_torch_QAT_quantizer(self):
model = TorchModel()
config_list = [{