refactor compression sdk (#1562)
* refactor compression sdk

* bugfix

* bugfix

* update ut
liuzhe-lz authored Sep 26, 2019
1 parent cc5af7e commit 70cf8f7
Showing 23 changed files with 584 additions and 785 deletions.
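In short, the refactor renames the module namespace from `nni.compressors.*` to `nni.compression.*`, the built-in classes from `AGPruner`/`QATquantizer` to `AGP_Pruner`/`QAT_Quantizer`, and the config key `support_type` to `op_type`. A minimal before/after sketch distilled from the diffs below (names and paths as they appear in this commit):

```python
# Old API (removed by this commit):
#   from nni.compressors.tf_compressor import AGPruner, QATquantizer
#   configure_list = [{'q_bits': 8, 'support_type': 'default'}]

# New API (introduced by this commit):
from nni.compression.tensorflow import AGP_Pruner, QAT_Quantizer
# from nni.compression.torch import AGP_Pruner, QAT_Quantizer  # PyTorch counterpart

config_list = [{'q_bits': 8, 'op_type': 'default'}]  # 'support_type' renamed to 'op_type'
quantizer = QAT_Quantizer(config_list)
```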
2 changes: 1 addition & 1 deletion examples/model_compress/configure_example.yaml
@@ -6,4 +6,4 @@ AGPruner:
     frequency: 1
     initial_sparsity: 0.05
     final_sparsity: 0.8
-    support_type: 'default'
+    op_type: 'default'
6 changes: 3 additions & 3 deletions examples/model_compress/main_tf_pruner.py
@@ -1,4 +1,4 @@
-from nni.compressors.tf_compressor import AGPruner
+from nni.compression.tensorflow import AGP_Pruner
 import tensorflow as tf
 from tensorflow.examples.tutorials.mnist import input_data
 
@@ -88,9 +88,9 @@ def main():
         'start_epoch': 1,
         'end_epoch': 10,
         'frequency': 1,
-        'support_type': 'default'
+        'op_type': 'default'
     }]
-    pruner = AGPruner(configure_list)
+    pruner = AGP_Pruner(configure_list)
     # if you want to load from yaml file
     # configure_file = nni.compressors.tf_compressor._nnimc_tf._tf_default_load_configure_file('configure_example.yaml','AGPruner')
     # configure_list = configure_file.get('config',[])
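The hunk above only swaps the class name; the full usage pattern for the TensorFlow AGP pruner, pieced together from this example and the `update_epoch(self, epoch, sess)` method in the pruner diff further down, looks roughly like the following sketch. The graph-instrumentation call `pruner(...)` is assumed by analogy with the quantizer example below, and the training code is a placeholder:

```python
import tensorflow as tf
from nni.compression.tensorflow import AGP_Pruner

config_list = [{
    'initial_sparsity': 0.05,
    'final_sparsity': 0.8,
    'start_epoch': 1,
    'end_epoch': 10,
    'frequency': 1,
    'op_type': 'default'
}]
pruner = AGP_Pruner(config_list)
pruner(tf.get_default_graph())  # instrument the graph, as quantizer(...) does below

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(10):
        # ... run training batches here ...
        pruner.update_epoch(epoch, sess)  # applies masks and advances the AGP schedule
```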
6 changes: 3 additions & 3 deletions examples/model_compress/main_tf_quantizer.py
@@ -1,4 +1,4 @@
-from nni.compressors.tf_compressor import QATquantizer
+from nni.compression.tensorflow import QAT_Quantizer
 import tensorflow as tf
 from tensorflow.examples.tutorials.mnist import input_data
 
@@ -82,8 +82,8 @@ def main():
     '''you can change this to DoReFaQuantizer to implement it
     DoReFaQuantizer(configure_list).compress(tf.get_default_graph())
     '''
-    configure_list = [{'q_bits':8, 'support_type':'default'}]
-    quantizer = QATquantizer(configure_list)
+    configure_list = [{'q_bits':8, 'op_type':'default'}]
+    quantizer = QAT_Quantizer(configure_list)
     quantizer(tf.get_default_graph())
     # you can also use compress(model) or compress_default_graph()
     # method like QATquantizer(q_bits = 8).compress_default_graph()
7 changes: 3 additions & 4 deletions examples/model_compress/main_torch_pruner.py
@@ -1,4 +1,4 @@
-from nni.compressors.torch_compressor import AGPruner
+from nni.compression.torch import AGP_Pruner
 import torch
 import torch.nn.functional as F
 from torchvision import datasets, transforms
@@ -74,11 +74,10 @@ def main():
         'start_epoch': 1,
         'end_epoch': 10,
         'frequency': 1,
-        'support_type': 'default'
+        'op_type': 'default'
     }]
 
-    pruner = AGPruner(configure_list)
-    pruner.load_configure('configure_example.yaml')
+    pruner = AGP_Pruner(configure_list)
     pruner(model)
     # you can also use compress(model) method
     # like that pruner.compress(model)
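On the PyTorch side the same pattern applies without a session. A hedged sketch of the epoch loop, assuming the torch `AGP_Pruner` exposes an `update_epoch(epoch)` hook analogous to the TensorFlow one shown later in this commit; the model and training step are stand-ins:

```python
import torch
from nni.compression.torch import AGP_Pruner

model = torch.nn.Sequential(torch.nn.Linear(784, 10))  # stand-in model

config_list = [{
    'initial_sparsity': 0.05,
    'final_sparsity': 0.8,
    'start_epoch': 1,
    'end_epoch': 10,
    'frequency': 1,
    'op_type': 'default'
}]
pruner = AGP_Pruner(config_list)
pruner(model)  # wrap the model's layers with masking, as in the example above

for epoch in range(10):
    # ... train(model, ...) here ...
    pruner.update_epoch(epoch)  # assumed epoch hook, mirroring the TF pruner
```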
6 changes: 3 additions & 3 deletions examples/model_compress/main_torch_quantizer.py
@@ -1,4 +1,4 @@
-from nni.compressors.torch_compressor import QATquantizer
+from nni.compression.torch import QAT_Quantizer
 import torch
 import torch.nn.functional as F
 from torchvision import datasets, transforms
@@ -68,8 +68,8 @@ def main():
     '''you can change this to DoReFaQuantizer to implement it
     DoReFaQuantizer(configure_list).compress(model)
     '''
-    configure_list = [{'q_bits':8, 'support_type':'default'}]
-    quantizer = QATquantizer(configure_list)
+    configure_list = [{'q_bits':8, 'op_type':'default'}]
+    quantizer = QAT_Quantizer(configure_list)
     quantizer(model)
     # you can also use compress(model) method
     # like that quantizer.compress(model)
3 changes: 3 additions & 0 deletions src/sdk/pynni/nni/compression/tensorflow/__init__.py
@@ -0,0 +1,3 @@
+from .compressor import LayerInfo, Compressor, Pruner, Quantizer
+from .builtin_pruners import *
+from .builtin_quantizers import *
src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py
@@ -1,36 +1,26 @@
+import logging
 import tensorflow as tf
-from ._nnimc_tf import TfPruner
-from ._nnimc_tf import _tf_default_get_configure, _tf_default_load_configure_file
+from .compressor import Pruner
 
+__all__ = [ 'LevelPruner', 'AGP_Pruner', 'SensitivityPruner' ]
+
+_logger = logging.getLogger(__name__)
 
-import logging
-logger = logging.getLogger('tensorflow pruner')
 
-class LevelPruner(TfPruner):
-    def __init__(self, configure_list):
+class LevelPruner(Pruner):
+    def __init__(self, config_list):
         """
         Configure Args:
             sparsity
         """
-        super().__init__()
-        self.configure_list = []
-        if isinstance(configure_list, list):
-            for configure in configure_list:
-                self.configure_list.append(configure)
-        else:
-            raise ValueError('please init with configure list')
-
-
-    def get_sparsity(self, configure={}):
-        sparsity = configure.get('sparsity', 0)
-        return sparsity
-
-    def calc_mask(self, layer_info, weight):
-        sparsity = self.get_sparsity(_tf_default_get_configure(self.configure_list, layer_info))
+        super().__init__(config_list)
 
-        threshold = tf.contrib.distributions.percentile(tf.abs(weight), sparsity * 100)
+    def calc_mask(self, layer, weight, config):
+        threshold = tf.contrib.distributions.percentile(tf.abs(weight), config['sparsity'] * 100)
         return tf.cast(tf.math.greater(tf.abs(weight), threshold), weight.dtype)
 
-class AGPruner(TfPruner):
+
+class AGP_Pruner(Pruner):
     """
     An automated gradual pruning algorithm that prunes the smallest magnitude
     weights to achieve a preset level of network sparsity.
@@ -40,7 +30,7 @@ class AGPruner(TfPruner):
     Learning of Phones and other Consumer Devices,
     https://arxiv.org/pdf/1710.01878.pdf
     """
-    def __init__(self, configure_list):
+    def __init__(self, config_list):
         """
         Configure Args
             initial_sparsity:
@@ -49,27 +39,27 @@ def __init__(self, configure_list):
             end_epoch: end epoch number stop update mask
             frequency: if you want update every 2 epoch, you can set it 2
         """
-        super().__init__()
-        self.configure_list = []
-        if isinstance(configure_list, list):
-            for configure in configure_list:
-                self.configure_list.append(configure)
-        else:
-            raise ValueError('please init with configure list')
-
+        super().__init__(config_list)
         self.now_epoch = tf.Variable(0)
         self.assign_handler = []
 
-    def compute_target_sparsity(self, layer_info):
-        configure = _tf_default_get_configure(self.configure_list, layer_info)
-        end_epoch = configure.get('end_epoch', 1)
-        start_epoch = configure.get('start_epoch', 0)
-        freq = configure.get('frequency', 1)
-        final_sparsity = configure.get('final_sparsity', 0)
-        initial_sparsity = configure.get('initial_sparsity', 0)
-
+    def calc_mask(self, layer, weight, config):
+        target_sparsity = self.compute_target_sparsity(config)
+        threshold = tf.contrib.distributions.percentile(weight, target_sparsity * 100)
+        # stop gradient in case gradient change the mask
+        mask = tf.stop_gradient(tf.cast(tf.math.greater(weight, threshold), weight.dtype))
+        self.assign_handler.append(tf.assign(weight, weight * mask))
+        return mask
+
+    def compute_target_sparsity(self, config):
+        end_epoch = config.get('end_epoch', 1)
+        start_epoch = config.get('start_epoch', 0)
+        freq = config.get('frequency', 1)
+        final_sparsity = config.get('final_sparsity', 0)
+        initial_sparsity = config.get('initial_sparsity', 0)
+
         if end_epoch <= start_epoch or initial_sparsity >= final_sparsity:
-            logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity')
+            _logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity')
             return final_sparsity
 
         now_epoch = tf.minimum(self.now_epoch, tf.constant(end_epoch))
@@ -80,58 +70,34 @@ def compute_target_sparsity(self, layer_info):
                            (initial_sparsity - final_sparsity) *
                            (tf.pow(1.0 - base, 3)))
         return target_sparsity
-
-
-    def calc_mask(self, layer_info, weight):
-
-        target_sparsity = self.compute_target_sparsity(layer_info)
-        threshold = tf.contrib.distributions.percentile(weight, target_sparsity * 100)
-        # stop gradient in case gradient change the mask
-        mask = tf.stop_gradient(tf.cast(tf.math.greater(weight, threshold), weight.dtype))
-        self.assign_handler.append(tf.assign(weight, weight*mask))
-        return mask
 
     def update_epoch(self, epoch, sess):
         sess.run(self.assign_handler)
         sess.run(tf.assign(self.now_epoch, int(epoch)))
 
 
-class SensitivityPruner(TfPruner):
+class SensitivityPruner(Pruner):
     """
     Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks"
    https://arxiv.org/pdf/1506.02626v3.pdf
     I.e.: "The pruning threshold is chosen as a quality parameter multiplied
     by the standard deviation of a layer's weights."
     """
-    def __init__(self, configure_list):
+    def __init__(self, config_list):
         """
         Configure Args:
             sparsity: chosen pruning sparsity
         """
-        super().__init__()
-        self.configure_list = []
-        if isinstance(configure_list, list):
-            for configure in configure_list:
-                self.configure_list.append(configure)
-        else:
-            raise ValueError('please init with configure list')
-
+        super().__init__(config_list)
         self.layer_mask = {}
         self.assign_handler = []
 
-    def get_sparsity(self, configure={}):
-        sparsity = configure.get('sparsity', 0)
-        return sparsity
-
-    def calc_mask(self, layer_info, weight):
-        sparsity = self.get_sparsity(_tf_default_get_configure(self.configure_list, layer_info))
-
-        target_sparsity = sparsity * tf.math.reduce_std(weight)
-        mask = tf.get_variable(layer_info.name+'_mask', initializer=tf.ones(weight.shape), trainable=False)
-        self.layer_mask[layer_info.name] = mask
-
+    def calc_mask(self, layer, weight, config):
+        target_sparsity = config['sparsity'] * tf.math.reduce_std(weight)
+        mask = tf.get_variable(layer.name + '_mask', initializer=tf.ones(weight.shape), trainable=False)
+        self.layer_mask[layer.name] = mask
+
         weight_assign_handler = tf.assign(weight, mask*weight)
         # use control_dependencies so that weight_assign_handler will be executed before mask_update_handler
         with tf.control_dependencies([weight_assign_handler]):
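After this refactor, a built-in pruner has exactly two obligations: pass `config_list` through to `Pruner.__init__`, and implement `calc_mask(layer, weight, config)` returning a mask tensor; config parsing and layer matching now live in the base classes. A sketch of a custom pruner against that contract (the fixed-threshold rule and the `threshold` config key are illustrative, not part of this commit):

```python
import tensorflow as tf
from nni.compression.tensorflow import Pruner

class ThresholdPruner(Pruner):
    """Illustrative pruner: zero out weights whose magnitude falls below a fixed threshold."""
    def __init__(self, config_list):
        super().__init__(config_list)  # base class stores configs and matches them to layers

    def calc_mask(self, layer, weight, config):
        # `config` is the entry from config_list that the base class matched to this layer
        threshold = tf.constant(config.get('threshold', 0.01))
        return tf.cast(tf.math.greater(tf.abs(weight), threshold), weight.dtype)

# Usage mirrors the built-ins:
#   pruner = ThresholdPruner([{'threshold': 0.05, 'op_type': 'default'}])
#   pruner(tf.get_default_graph())
```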
74 changes: 74 additions & 0 deletions src/sdk/pynni/nni/compression/tensorflow/builtin_quantizers.py
@@ -0,0 +1,74 @@
+import logging
+import tensorflow as tf
+from .compressor import Quantizer
+
+__all__ = [ 'NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer' ]
+
+_logger = logging.getLogger(__name__)
+
+
+class NaiveQuantizer(Quantizer):
+    """
+    quantize weight to 8 bits
+    """
+    def __init__(self, config_list):
+        super().__init__(config_list)
+        self.layer_scale = { }
+
+    def quantize_weight(self, layer, weight, config):
+        new_scale = tf.reduce_max(tf.abs(weight)) / 127
+        scale = tf.maximum(self.layer_scale.get(layer.name, tf.constant(0.0)), new_scale)
+        self.layer_scale[layer.name] = scale
+        orig_type = weight.dtype
+        return tf.cast(tf.cast(weight / scale, tf.int8), orig_type) * scale
+
+
+class QAT_Quantizer(Quantizer):
+    """
+    Quantizer using quantization-aware training, as defined in:
+    Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
+    http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
+    """
+    def __init__(self, config_list):
+        """
+        Configure Args:
+            q_bits
+        """
+        super().__init__(config_list)
+
+    def quantize_weight(self, layer, weight, config):
+        a = tf.stop_gradient(tf.reduce_min(weight))
+        b = tf.stop_gradient(tf.reduce_max(weight))
+        n = tf.cast(2 ** config['q_bits'], tf.float32)
+        scale = (b - a) / (n - 1)  # step between adjacent quantization levels
+
+        # use gradient_override_map to change round to identity for gradient
+        with tf.get_default_graph().gradient_override_map({'Round': 'Identity'}):
+            qw = tf.round((weight - a) / scale) * scale + a
+
+        return qw
+
+
+class DoReFaQuantizer(Quantizer):
+    """
+    Quantizer using the DoReFa scheme, as defined in:
+    Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
+    (https://arxiv.org/abs/1606.06160)
+    """
+    def __init__(self, config_list):
+        """
+        Configure Args:
+            q_bits
+        """
+        super().__init__(config_list)
+
+    def quantize_weight(self, layer, weight, config):
+        a = tf.math.tanh(weight)
+        b = a / (2 * tf.reduce_max(tf.abs(a))) + 0.5  # per the paper, normalize by max |tanh(w)|
+
+        scale = pow(2, config['q_bits'] - 1)
+        # use gradient_override_map to change round to identity for gradient
+        with tf.get_default_graph().gradient_override_map({'Round': 'Identity'}):
+            qw = tf.round(b * scale) / scale
+        r_qw = 2 * qw - 1
+        return r_qw
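For intuition, `QAT_Quantizer.quantize_weight` above is a uniform affine quantizer: with `a = min(w)`, `b = max(w)`, and `n = 2**q_bits` levels, every weight snaps to the nearest point on an evenly spaced grid over `[a, b]`. A small NumPy sketch of the same arithmetic (names here are illustrative, not NNI API):

```python
import numpy as np

def affine_quantize(w, q_bits=8):
    a, b = w.min(), w.max()
    n = float(2 ** q_bits)
    scale = (b - a) / (n - 1)                      # step between adjacent levels
    return np.round((w - a) / scale) * scale + a   # snap to the nearest level

w = np.array([-0.31, 0.02, 0.11, 0.47])
print(affine_quantize(w, q_bits=2))  # 4 levels: each value maps onto the [min, max] grid
```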