Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

fix builtin pruners bug #1612

Merged
merged 2 commits into from
Oct 16, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 25 additions & 23 deletions src/sdk/pynni/nni/compression/torch/builtin_pruners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import torch
from .compressor import Pruner

__all__ = [ 'LevelPruner', 'AGP_Pruner', 'SensitivityPruner' ]
__all__ = ['LevelPruner', 'AGP_Pruner', 'SensitivityPruner']

logger = logging.getLogger('torch pruner')


class LevelPruner(Pruner):
"""Prune to an exact pruning level specification
"""

def __init__(self, config_list):
"""
config_list: supported keys:
Expand All @@ -21,9 +22,9 @@ def calc_mask(self, weight, config, **kwargs):
w_abs = weight.abs()
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape)
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max()
return torch.gt(w_abs, threshold).type(weight.type())
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
return torch.gt(w_abs, threshold).type_as(weight)


class AGP_Pruner(Pruner):
Expand All @@ -35,12 +36,13 @@ class AGP_Pruner(Pruner):
Learning of Phones and other Consumer Devices,
https://arxiv.org/pdf/1710.01878.pdf
"""

def __init__(self, config_list):
"""
config_list: supported keys:
- initial_sparsity
- final_sparsity: you should make sure initial_sparsity <= final_sparsity
- start_epoch: start epoch numer begin update mask
- start_epoch: start epoch number begin update mask
- end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch
- frequency: if you want update every 2 epoch, you can set it 2
"""
Expand All @@ -49,15 +51,15 @@ def __init__(self, config_list):
self.now_epoch = 1

def calc_mask(self, weight, config, op_name, **kwargs):
mask = self.mask_list.get(op_name, torch.ones(weight.shape))
mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate new mask, we should update weigth first
w_abs = weight.abs()*mask
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max()
new_mask = torch.gt(w_abs, threshold).type(weight.type())
w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list[op_name] = new_mask
return new_mask

Expand All @@ -74,45 +76,45 @@ def compute_target_sparsity(self, config):
if end_epoch <= self.now_epoch:
return final_sparsity

span = ((end_epoch - start_epoch-1)//freq)*freq
span = ((end_epoch - start_epoch - 1) // freq) * freq
assert span > 0
target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity)*
(1.0 - ((self.now_epoch - start_epoch)/span))**3)
target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity) *
(1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
return target_sparsity

def update_epoch(self, epoch):
if epoch > 0:
self.now_epoch = epoch


class SensitivityPruner(Pruner):
"""Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks"
https://arxiv.org/pdf/1506.02626v3.pdf

I.e.: "The pruning threshold is chosen as a quality parameter multiplied
by the standard deviation of a layers weights."
"""

def __init__(self, config_list):
"""
config_list: supported keys:
- sparsity: chosen pruning sparsity
"""
super().__init__(config_list)
self.mask_list = {}



def calc_mask(self, weight, config, op_name, **kwargs):
mask = self.mask_list.get(op_name, torch.ones(weight.shape))
# if we want to generate new mask, we should update weigth first
weight = weight*mask
mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
# if we want to generate new mask, we should update weight first
weight = weight * mask
target_sparsity = config['sparsity'] * torch.std(weight).item()
k = int(weight.numel() * target_sparsity)
if k == 0:
return mask

w_abs = weight.abs()
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max()
new_mask = torch.gt(w_abs, threshold).type(weight.type())
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list[op_name] = new_mask
return new_mask