From 168d74e448d19aad13ad7cde670787248c827867 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Tue, 8 Oct 2019 19:36:50 +0800 Subject: [PATCH 1/3] Add example for customized advisor and some refactoring (#1569) Add example for customized advisor and some refactoring --- docs/en_US/Tuner/CustomizeAdvisor.md | 24 ++- docs/en_US/sdk_reference.rst | 3 + .../mnist_keras_customized_advisor/config.yml | 20 +++ .../dummy_advisor.py | 95 ++++++++++++ .../mnist_keras.py | 137 ++++++++++++++++++ .../search_space.json | 5 + .../hyperband_advisor/hyperband_advisor.py | 69 +++++---- src/sdk/pynni/nni/msg_dispatcher.py | 18 ++- src/sdk/pynni/nni/msg_dispatcher_base.py | 87 ++++++++++- src/sdk/pynni/tests/test_assessor.py | 6 +- src/sdk/pynni/tests/test_tuner.py | 14 +- 11 files changed, 408 insertions(+), 70 deletions(-) create mode 100644 examples/tuners/mnist_keras_customized_advisor/config.yml create mode 100644 examples/tuners/mnist_keras_customized_advisor/dummy_advisor.py create mode 100644 examples/tuners/mnist_keras_customized_advisor/mnist_keras.py create mode 100644 examples/tuners/mnist_keras_customized_advisor/search_space.json diff --git a/docs/en_US/Tuner/CustomizeAdvisor.md b/docs/en_US/Tuner/CustomizeAdvisor.md index 8dcb8330d4..aefdd959ad 100644 --- a/docs/en_US/Tuner/CustomizeAdvisor.md +++ b/docs/en_US/Tuner/CustomizeAdvisor.md @@ -1,16 +1,12 @@ # **How To** - Customize Your Own Advisor -*Advisor targets the scenario that the automl algorithm wants the methods of both tuner and assessor. Advisor is similar to tuner on that it receives trial parameters request, final results, and generate trial parameters. Also, it is similar to assessor on that it receives intermediate results, trial's end state, and could send trial kill command. Note that, if you use Advisor, tuner and assessor are not allowed to be used at the same time.* +*Warning: API is subject to change in future releases.* -So, if user want to implement a customized Advisor, she/he only need to: +Advisor targets the scenario that the automl algorithm wants the methods of both tuner and assessor. Advisor is similar to tuner on that it receives trial parameters request, final results, and generate trial parameters. Also, it is similar to assessor on that it receives intermediate results, trial's end state, and could send trial kill command. Note that, if you use Advisor, tuner and assessor are not allowed to be used at the same time. -1. Define an Advisor inheriting from the MsgDispatcherBase class -1. Implement the methods with prefix `handle_` except `handle_request` -1. Configure your customized Advisor in experiment YAML config file +If a user want to implement a customized Advisor, she/he only needs to: -Here is an example: - -**1) Define an Advisor inheriting from the MsgDispatcherBase class** +**1. Define an Advisor inheriting from the MsgDispatcherBase class.** For example: ```python from nni.msg_dispatcher_base import MsgDispatcherBase @@ -20,13 +16,11 @@ class CustomizedAdvisor(MsgDispatcherBase): ... ``` -**2) Implement the methods with prefix `handle_` except `handle_request`** - -Please refer to the implementation of Hyperband ([src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py](https://github.com/Microsoft/nni/tree/master/src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py)) for how to implement the methods. +**2. Implement the methods with prefix `handle_` except `handle_request`**.. You might find [docs](https://nni.readthedocs.io/en/latest/sdk_reference.html#nni.msg_dispatcher_base.MsgDispatcherBase) for `MsgDispatcherBase` helpful. -**3) Configure your customized Advisor in experiment YAML config file** +**3. Configure your customized Advisor in experiment YAML config file.** -Similar to tuner and assessor. NNI needs to locate your customized Advisor class and instantiate the class, so you need to specify the location of the customized Advisor class and pass literal values as parameters to the \_\_init__ constructor. +Similar to tuner and assessor. NNI needs to locate your customized Advisor class and instantiate the class, so you need to specify the location of the customized Advisor class and pass literal values as parameters to the `__init__` constructor. ```yaml advisor: @@ -38,3 +32,7 @@ advisor: classArgs: arg1: value1 ``` + +## Example + +Here we provide an [example](../../../examples/tuners/mnist_keras_customized_advisor). diff --git a/docs/en_US/sdk_reference.rst b/docs/en_US/sdk_reference.rst index 5c3047ba37..64de1ee45e 100644 --- a/docs/en_US/sdk_reference.rst +++ b/docs/en_US/sdk_reference.rst @@ -50,6 +50,9 @@ Assessor Advisor ------------------------ +.. autoclass:: nni.msg_dispatcher_base.MsgDispatcherBase + :members: + .. autoclass:: nni.hyperband_advisor.hyperband_advisor.Hyperband :members: diff --git a/examples/tuners/mnist_keras_customized_advisor/config.yml b/examples/tuners/mnist_keras_customized_advisor/config.yml new file mode 100644 index 0000000000..0d8d987ac3 --- /dev/null +++ b/examples/tuners/mnist_keras_customized_advisor/config.yml @@ -0,0 +1,20 @@ +authorName: default +experimentName: example_customized_advisor +trialConcurrency: 4 +maxExecDuration: 1h +maxTrialNum: 200 +#choice: local, remote, pai +trainingServicePlatform: local +searchSpacePath: search_space.json +#choice: true, false +useAnnotation: false +advisor: + codeDir: . + classFileName: dummy_advisor.py + className: DummyAdvisor + classArgs: + k: 3 +trial: + command: python3 mnist_keras.py --epochs 100 --num_train 600 --num_test 100 + codeDir: . + gpuNum: 0 diff --git a/examples/tuners/mnist_keras_customized_advisor/dummy_advisor.py b/examples/tuners/mnist_keras_customized_advisor/dummy_advisor.py new file mode 100644 index 0000000000..5123b598fa --- /dev/null +++ b/examples/tuners/mnist_keras_customized_advisor/dummy_advisor.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import logging +from collections import defaultdict + +import json_tricks +import numpy as np +from nni import parameter_expressions as param +from nni.msg_dispatcher_base import MsgDispatcherBase +from nni.protocol import CommandType, send +from nni.utils import MetricType + +logger = logging.getLogger('customized_advisor') + + +class DummyAdvisor(MsgDispatcherBase): + """WARNING: Advisor API is subject to change in future releases. + + This advisor creates a new trial when validation accuracy of any one of the trials just dropped. + The trial is killed if the validation accuracy doesn't improve for at least k last-reported metrics. + To demonstrate the high flexibility of writing advisors, we don't use tuners or the standard definition of + search space. This is just a demo to customize an advisor. It's not intended to make any sense. + """ + def __init__(self, k=3): + super(DummyAdvisor, self).__init__() + self.k = k + self.random_state = np.random.RandomState() + + def handle_initialize(self, data): + logger.info("Advisor initialized: {}".format(data)) + self.handle_update_search_space(data) + self.parameters_count = 0 + self.parameter_best_metric = defaultdict(float) + self.parameter_cooldown = defaultdict(int) + send(CommandType.Initialized, '') + + def _send_new_trial(self): + self.parameters_count += 1 + new_trial = { + "parameter_id": self.parameters_count, + "parameters": { + "optimizer": param.choice(self.searchspace_json["optimizer"], self.random_state), + "learning_rate": param.loguniform(self.searchspace_json["learning_rate"][0], + self.searchspace_json["learning_rate"][1], + self.random_state) + }, + "parameter_source": "algorithm" + } + logger.info("New trial sent: {}".format(new_trial)) + send(CommandType.NewTrialJob, json_tricks.dumps(new_trial)) + + def handle_request_trial_jobs(self, data): + logger.info("Request trial jobs: {}".format(data)) + for _ in range(data): + self._send_new_trial() + + def handle_update_search_space(self, data): + logger.info("Search space update: {}".format(data)) + self.searchspace_json = data + + def handle_trial_end(self, data): + logger.info("Trial end: {}".format(data)) # do nothing + + def handle_report_metric_data(self, data): + logger.info("Metric reported: {}".format(data)) + if data['type'] == MetricType.REQUEST_PARAMETER: + raise ValueError("Request parameter not supported") + elif data["type"] == MetricType.PERIODICAL: + parameter_id = data["parameter_id"] + if data["value"] > self.parameter_best_metric[parameter_id]: + self.parameter_best_metric[parameter_id] = data["value"] + self.parameter_cooldown[parameter_id] = 0 + else: + self.parameter_cooldown[parameter_id] += 1 + logger.info("Accuracy dropped, cooldown {}, sending a new trial".format( + self.parameter_cooldown[parameter_id])) + self._send_new_trial() + if self.parameter_cooldown[parameter_id] >= self.k: + logger.info("Send kill signal to {}".format(data)) + send(CommandType.KillTrialJob, json_tricks.dumps(data["trial_job_id"])) diff --git a/examples/tuners/mnist_keras_customized_advisor/mnist_keras.py b/examples/tuners/mnist_keras_customized_advisor/mnist_keras.py new file mode 100644 index 0000000000..ee74a085ca --- /dev/null +++ b/examples/tuners/mnist_keras_customized_advisor/mnist_keras.py @@ -0,0 +1,137 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import argparse +import logging + +import os +import keras +import numpy as np +from keras import backend as K +from keras.callbacks import TensorBoard +from keras.datasets import mnist +from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D +from keras.models import Sequential + +import nni + +LOG = logging.getLogger('mnist_keras') +K.set_image_data_format('channels_last') +TENSORBOARD_DIR = os.environ['NNI_OUTPUT_DIR'] + +H, W = 28, 28 +NUM_CLASSES = 10 + + +def create_mnist_model(hyper_params, input_shape=(H, W, 1), num_classes=NUM_CLASSES): + """ + Create simple convolutional model + """ + layers = [ + Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape), + Conv2D(64, (3, 3), activation='relu'), + MaxPooling2D(pool_size=(2, 2)), + Flatten(), + Dense(100, activation='relu'), + Dense(num_classes, activation='softmax') + ] + + model = Sequential(layers) + + if hyper_params['optimizer'] == 'Adam': + optimizer = keras.optimizers.Adam(lr=hyper_params['learning_rate']) + else: + optimizer = keras.optimizers.SGD(lr=hyper_params['learning_rate'], momentum=0.9) + model.compile(loss=keras.losses.categorical_crossentropy, optimizer=optimizer, metrics=['accuracy']) + + return model + + +def load_mnist_data(args): + """ + Load MNIST dataset + """ + (x_train, y_train), (x_test, y_test) = mnist.load_data() + + x_train = (np.expand_dims(x_train, -1).astype(np.float) / 255.)[:args.num_train] + x_test = (np.expand_dims(x_test, -1).astype(np.float) / 255.)[:args.num_test] + y_train = keras.utils.to_categorical(y_train, NUM_CLASSES)[:args.num_train] + y_test = keras.utils.to_categorical(y_test, NUM_CLASSES)[:args.num_test] + + LOG.debug('x_train shape: %s', (x_train.shape,)) + LOG.debug('x_test shape: %s', (x_test.shape,)) + + return x_train, y_train, x_test, y_test + + +class SendMetrics(keras.callbacks.Callback): + """ + Keras callback to send metrics to NNI framework + """ + + def on_epoch_end(self, epoch, logs={}): + """ + Run on end of each epoch + """ + LOG.debug(logs) + # Should this be val_acc or val_accuracy? Seems inconsistent behavior of Keras? + nni.report_intermediate_result(logs["val_accuracy"]) + + +def train(args, params): + """ + Train model + """ + x_train, y_train, x_test, y_test = load_mnist_data(args) + model = create_mnist_model(params) + + model.fit(x_train, y_train, batch_size=args.batch_size, epochs=args.epochs, verbose=1, + validation_data=(x_test, y_test), callbacks=[SendMetrics(), TensorBoard(log_dir=TENSORBOARD_DIR)]) + + _, acc = model.evaluate(x_test, y_test, verbose=0) + LOG.debug('Final result is: %d', acc) + nni.report_final_result(acc) + + +def generate_default_params(): + """ + Generate default hyper parameters + """ + return { + 'optimizer': 'Adam', + 'learning_rate': 0.001 + } + + +if __name__ == '__main__': + PARSER = argparse.ArgumentParser() + PARSER.add_argument("--batch_size", type=int, default=200, help="batch size", required=False) + PARSER.add_argument("--epochs", type=int, default=10, help="Train epochs", required=False) + PARSER.add_argument("--num_train", type=int, default=60000, + help="Number of train samples to be used, maximum 60000", required=False) + PARSER.add_argument("--num_test", type=int, default=10000, help="Number of test samples to be used, maximum 10000", + required=False) + + ARGS, UNKNOWN = PARSER.parse_known_args() + + # get parameters from tuner + RECEIVED_PARAMS = nni.get_next_parameter() + LOG.debug(RECEIVED_PARAMS) + PARAMS = generate_default_params() + PARAMS.update(RECEIVED_PARAMS) + # train + train(ARGS, PARAMS) diff --git a/examples/tuners/mnist_keras_customized_advisor/search_space.json b/examples/tuners/mnist_keras_customized_advisor/search_space.json new file mode 100644 index 0000000000..dadb04bc25 --- /dev/null +++ b/examples/tuners/mnist_keras_customized_advisor/search_space.json @@ -0,0 +1,5 @@ +{ + "README": "To demonstrate the flexibility, this search space does not follow the standard definition.", + "optimizer": ["Adam", "SGD"], + "learning_rate": [0.001, 0.1] +} diff --git a/src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py b/src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py index f596e5ea3b..7e376c6d9e 100644 --- a/src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py +++ b/src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py @@ -21,18 +21,17 @@ hyperband_advisor.py """ -import sys -import math import copy import logging -import numpy as np -import json_tricks +import math +import sys -from nni.protocol import CommandType, send +import json_tricks +import numpy as np +from nni.common import multi_phase_enabled from nni.msg_dispatcher_base import MsgDispatcherBase -from nni.common import init_logger, multi_phase_enabled +from nni.protocol import CommandType, send from nni.utils import NodeType, OptimizeMode, MetricType, extract_scalar_reward -import nni.parameter_expressions as parameter_expressions _logger = logging.getLogger(__name__) @@ -53,6 +52,7 @@ def create_parameter_id(): _next_parameter_id += 1 return _next_parameter_id - 1 + def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-1): """Create a full id for a specific bracket's hyperparameter configuration @@ -77,6 +77,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=- increased_id]) return params_id + def json2parameter(ss_spec, random_state): """Randomly generate values for hyperparameters from hyperparameter space i.e., x. @@ -100,7 +101,7 @@ def json2parameter(ss_spec, random_state): _index = random_state.randint(len(_value)) chosen_params = json2parameter(ss_spec[NodeType.VALUE][_index], random_state) else: - chosen_params = eval('parameter_expressions.' + # pylint: disable=eval-used + chosen_params = eval('parameter_expressions.' + # pylint: disable=eval-used _type)(*(_value + [random_state])) else: chosen_params = dict() @@ -114,6 +115,7 @@ def json2parameter(ss_spec, random_state): chosen_params = copy.deepcopy(ss_spec) return chosen_params + class Bracket(): """A bracket in Hyperband, all the information of a bracket is managed by an instance of this class @@ -137,12 +139,12 @@ def __init__(self, s, s_max, eta, R, optimize_mode): self.bracket_id = s self.s_max = s_max self.eta = eta - self.n = math.ceil((s_max + 1) * (eta**s) / (s + 1) - _epsilon) # pylint: disable=invalid-name - self.r = R / eta**s # pylint: disable=invalid-name + self.n = math.ceil((s_max + 1) * (eta ** s) / (s + 1) - _epsilon) # pylint: disable=invalid-name + self.r = R / eta ** s # pylint: disable=invalid-name self.i = 0 - self.hyper_configs = [] # [ {id: params}, {}, ... ] - self.configs_perf = [] # [ {id: [seq, acc]}, {}, ... ] - self.num_configs_to_run = [] # [ n, n, n, ... ] + self.hyper_configs = [] # [ {id: params}, {}, ... ] + self.configs_perf = [] # [ {id: [seq, acc]}, {}, ... ] + self.num_configs_to_run = [] # [ n, n, n, ... ] self.num_finished_configs = [] # [ n, n, n, ... ] self.optimize_mode = OptimizeMode(optimize_mode) self.no_more_trial = False @@ -153,7 +155,7 @@ def is_completed(self): def get_n_r(self): """return the values of n and r for the next round""" - return math.floor(self.n / self.eta**self.i + _epsilon), math.floor(self.r * self.eta**self.i + _epsilon) + return math.floor(self.n / self.eta ** self.i + _epsilon), math.floor(self.r * self.eta ** self.i + _epsilon) def increase_i(self): """i means the ith round. Increase i by 1""" @@ -185,7 +187,6 @@ def set_config_perf(self, i, parameter_id, seq, value): else: self.configs_perf[i][parameter_id] = [seq, value] - def inform_trial_end(self, i): """If the trial is finished and the corresponding round (i.e., i) has all its trials finished, it will choose the top k trials for the next round (i.e., i+1) @@ -195,16 +196,17 @@ def inform_trial_end(self, i): i: int the ith round """ - global _KEY # pylint: disable=global-statement + global _KEY # pylint: disable=global-statement self.num_finished_configs[i] += 1 - _logger.debug('bracket id: %d, round: %d %d, finished: %d, all: %d', self.bracket_id, self.i, i, self.num_finished_configs[i], self.num_configs_to_run[i]) + _logger.debug('bracket id: %d, round: %d %d, finished: %d, all: %d', self.bracket_id, self.i, i, + self.num_finished_configs[i], self.num_configs_to_run[i]) if self.num_finished_configs[i] >= self.num_configs_to_run[i] \ - and self.no_more_trial is False: + and self.no_more_trial is False: # choose candidate configs from finished configs to run in the next round assert self.i == i + 1 this_round_perf = self.configs_perf[i] if self.optimize_mode is OptimizeMode.Maximize: - sorted_perf = sorted(this_round_perf.items(), key=lambda kv: kv[1][1], reverse=True) # reverse + sorted_perf = sorted(this_round_perf.items(), key=lambda kv: kv[1][1], reverse=True) # reverse else: sorted_perf = sorted(this_round_perf.items(), key=lambda kv: kv[1][1]) _logger.debug('bracket %s next round %s, sorted hyper configs: %s', self.bracket_id, self.i, sorted_perf) @@ -214,7 +216,7 @@ def inform_trial_end(self, i): for k in range(next_n): params_id = sorted_perf[k][0] params = self.hyper_configs[i][params_id] - params[_KEY] = next_r # modify r + params[_KEY] = next_r # modify r # generate new id increased_id = params_id.split('_')[-1] new_id = create_bracket_parameter_id(self.bracket_id, self.i, increased_id) @@ -223,7 +225,7 @@ def inform_trial_end(self, i): return [[key, value] for key, value in hyper_configs.items()] return None - def get_hyperparameter_configurations(self, num, r, searchspace_json, random_state): # pylint: disable=invalid-name + def get_hyperparameter_configurations(self, num, r, searchspace_json, random_state): # pylint: disable=invalid-name """Randomly generate num hyperparameter configurations from search space Parameters @@ -236,7 +238,7 @@ def get_hyperparameter_configurations(self, num, r, searchspace_json, random_sta list a list of hyperparameter configurations. Format: [[key1, value1], [key2, value2], ...] """ - global _KEY # pylint: disable=global-statement + global _KEY # pylint: disable=global-statement assert self.i == 0 hyperparameter_configs = dict() for _ in range(num): @@ -263,6 +265,7 @@ def _record_hyper_configs(self, hyper_configs): self.num_configs_to_run.append(len(hyper_configs)) self.increase_i() + class Hyperband(MsgDispatcherBase): """Hyperband inherit from MsgDispatcherBase rather than Tuner, because it integrates both tuner's functions and assessor's functions. This is an implementation that could fully leverage available resources, i.e., high parallelism. @@ -277,14 +280,15 @@ class Hyperband(MsgDispatcherBase): optimize_mode: str optimize mode, 'maximize' or 'minimize' """ + def __init__(self, R=60, eta=3, optimize_mode='maximize'): """B = (s_max + 1)R""" super(Hyperband, self).__init__() - self.R = R # pylint: disable=invalid-name + self.R = R # pylint: disable=invalid-name self.eta = eta - self.brackets = dict() # dict of Bracket - self.generated_hyper_configs = [] # all the configs waiting for run - self.completed_hyper_configs = [] # all the completed configs + self.brackets = dict() # dict of Bracket + self.generated_hyper_configs = [] # all the configs waiting for run + self.completed_hyper_configs = [] # all the completed configs self.s_max = math.floor(math.log(self.R, self.eta) + _epsilon) self.curr_s = self.s_max @@ -302,12 +306,11 @@ def __init__(self, R=60, eta=3, optimize_mode='maximize'): self.job_id_para_id_map = dict() def handle_initialize(self, data): - """data is search space - + """callback for initializing the advisor Parameters ---------- - data: int - number of trial jobs + data: dict + search space """ self.handle_update_search_space(data) send(CommandType.Initialized, '') @@ -348,14 +351,8 @@ def _get_one_trial_job(self): } return ret - def handle_update_search_space(self, data): """data: JSON object, which is search space - - Parameters - ---------- - data: int - number of trial jobs """ self.searchspace_json = data self.random_state = np.random.RandomState() diff --git a/src/sdk/pynni/nni/msg_dispatcher.py b/src/sdk/pynni/nni/msg_dispatcher.py index 1467b27695..64459e3a57 100644 --- a/src/sdk/pynni/nni/msg_dispatcher.py +++ b/src/sdk/pynni/nni/msg_dispatcher.py @@ -42,8 +42,9 @@ TODO: move this logic to NNI manager ''' + def _sort_history(history): - ret = [ ] + ret = [] for i, _ in enumerate(history): if i in history: ret.append(history[i]) @@ -51,17 +52,20 @@ def _sort_history(history): break return ret + # Tuner global variables _next_parameter_id = 0 _trial_params = {} '''key: trial job ID; value: parameters''' _customized_parameter_ids = set() + def _create_parameter_id(): global _next_parameter_id # pylint: disable=global-statement _next_parameter_id += 1 return _next_parameter_id - 1 + def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None): _trial_params[parameter_id] = params ret = { @@ -77,6 +81,7 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p ret['parameter_index'] = 0 return json_tricks.dumps(ret) + class MsgDispatcher(MsgDispatcherBase): def __init__(self, tuner, assessor=None): super(MsgDispatcher, self).__init__() @@ -123,7 +128,7 @@ def handle_update_search_space(self, data): def handle_import_data(self, data): """Import additional data for tuning - data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value' + data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value' """ self.tuner.import_data(data) @@ -154,7 +159,8 @@ def handle_report_metric_data(self, data): param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id']) except NoMoreTrialError: param = None - send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], parameter_index=data['parameter_index'])) + send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], + parameter_index=data['parameter_index'])) else: raise ValueError('Data type not supported: {}'.format(data['type'])) @@ -188,7 +194,8 @@ def _handle_final_metric_data(self, data): customized = True else: customized = False - self.tuner.receive_trial_result(id_, _trial_params[id_], value, customized=customized, trial_job_id=data.get('trial_job_id')) + self.tuner.receive_trial_result(id_, _trial_params[id_], value, customized=customized, + trial_job_id=data.get('trial_job_id')) def _handle_intermediate_metric_data(self, data): """Call assessor to process intermediate results @@ -223,7 +230,8 @@ def _handle_intermediate_metric_data(self, data): _logger.debug('BAD, kill %s', trial_job_id) send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id)) # notify tuner - _logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS) + _logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', + dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS) if dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS == 'true': self._earlystop_notify_tuner(data) else: diff --git a/src/sdk/pynni/nni/msg_dispatcher_base.py b/src/sdk/pynni/nni/msg_dispatcher_base.py index c98749e981..9680494c6c 100644 --- a/src/sdk/pynni/nni/msg_dispatcher_base.py +++ b/src/sdk/pynni/nni/msg_dispatcher_base.py @@ -18,7 +18,6 @@ # OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # ================================================================================================== -#import json_tricks import os import threading import logging @@ -39,7 +38,12 @@ QUEUE_LEN_WARNING_MARK = 20 _worker_fast_exit_on_terminate = True + class MsgDispatcherBase(Recoverable): + """This is where tuners and assessors are not defined yet. + Inherits this class to make your own advisor. + """ + def __init__(self): if multi_thread_enabled(): self.pool = ThreadPool() @@ -49,7 +53,8 @@ def __init__(self): self.default_command_queue = Queue() self.assessor_command_queue = Queue() self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,)) - self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,)) + self.assessor_worker = threading.Thread(target=self.command_queue_worker, + args=(self.assessor_command_queue,)) self.default_worker.start() self.assessor_worker.start() self.worker_exceptions = [] @@ -72,7 +77,8 @@ def run(self): if multi_thread_enabled(): result = self.pool.map_async(self.process_command_thread, [(command, data)]) self.thread_results.append(result) - if any([thread_result.ready() and not thread_result.successful() for thread_result in self.thread_results]): + if any([thread_result.ready() and not thread_result.successful() for thread_result in + self.thread_results]): _logger.debug('Caught thread exception') break else: @@ -112,7 +118,8 @@ def command_queue_worker(self, command_queue): def enqueue_command(self, command, data): """Enqueue command into command queues """ - if command == CommandType.TrialEnd or (command == CommandType.ReportMetricData and data['type'] == 'PERIODICAL'): + if command == CommandType.TrialEnd or ( + command == CommandType.ReportMetricData and data['type'] == 'PERIODICAL'): self.assessor_command_queue.put((command, data)) else: self.default_command_queue.put((command, data)) @@ -142,14 +149,14 @@ def process_command(self, command, data): _logger.debug('process_command: command: [{}], data: [{}]'.format(command, data)) command_handlers = { - # Tunner commands: + # Tuner commands: CommandType.Initialize: self.handle_initialize, CommandType.RequestTrialJobs: self.handle_request_trial_jobs, CommandType.UpdateSearchSpace: self.handle_update_search_space, CommandType.ImportData: self.handle_import_data, CommandType.AddCustomizedTrialJob: self.handle_add_customized_trial, - # Tunner/Assessor commands: + # Tuner/Assessor commands: CommandType.ReportMetricData: self.handle_report_metric_data, CommandType.TrialEnd: self.handle_trial_end, @@ -163,22 +170,88 @@ def handle_ping(self, data): pass def handle_initialize(self, data): + """Initialize search space and tuner, if any + This method is meant to be called only once for each experiment, after calling this method, + dispatcher should `send(CommandType.Initialized, '')`, to set the status of the experiment to be "INITIALIZED". + Parameters + ---------- + data: dict + search space + """ raise NotImplementedError('handle_initialize not implemented') def handle_request_trial_jobs(self, data): + """The message dispatcher is demanded to generate `data` trial jobs. + These trial jobs should be sent via `send(CommandType.NewTrialJob, json_tricks.dumps(parameter))`, + where `parameter` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter". + Semantically, message dispatcher should do this `send` exactly `data` times. + + The JSON sent by this method should follow the format of + { + "parameter_id": 42 + "parameters": { + // this will be received by trial + }, + "parameter_source": "algorithm" // optional + } + Parameters + ---------- + data: int + number of trial jobs + """ raise NotImplementedError('handle_request_trial_jobs not implemented') def handle_update_search_space(self, data): - raise NotImplementedError('handle_update_search_space not implemented') + """This method will be called when search space is updated. + It's recommended to call this method in `handle_initialize` to initialize search space. + *No need to* notify NNI Manager when this update is done. + Parameters + ---------- + data: dict + search space + """ + raise NotImplementedError('handle_update_search_space not implemented') def handle_import_data(self, data): + """Import previous data when experiment is resumed. + Parameters + ---------- + data: list + a list of dictionaries, each of which has at least two keys, 'parameter' and 'value' + """ raise NotImplementedError('handle_import_data not implemented') def handle_add_customized_trial(self, data): + """Experimental API. Not recommended for usage. + """ raise NotImplementedError('handle_add_customized_trial not implemented') def handle_report_metric_data(self, data): + """Called when metric data is reported or new parameters are requested (for multiphase). + When new parameters are requested, this method should send a new parameter. + Parameters + ---------- + data: dict + a dict which contains 'parameter_id', 'value', 'trial_job_id', 'type', 'sequence'. + type: can be `MetricType.REQUEST_PARAMETER`, `MetricType.FINAL` or `MetricType.PERIODICAL`. + `REQUEST_PARAMETER` is used to request new parameters for multiphase trial job. In this case, + the dict will contain additional keys: `trial_job_id`, `parameter_index`. Refer to `msg_dispatcher.py` + as an example. + Raises + ------ + ValueError + Data type is not supported + """ raise NotImplementedError('handle_report_metric_data not implemented') def handle_trial_end(self, data): + """Called when the state of one of the trials is changed + Parameters + ---------- + data: dict + a dict with keys: trial_job_id, event, hyper_params. + trial_job_id: the id generated by training service. + event: the job’s state. + hyper_params: the string that is sent by message dispatcher during the creation of trials. + """ raise NotImplementedError('handle_trial_end not implemented') diff --git a/src/sdk/pynni/tests/test_assessor.py b/src/sdk/pynni/tests/test_assessor.py index 9f992377cd..f1b2913b7a 100644 --- a/src/sdk/pynni/tests/test_assessor.py +++ b/src/sdk/pynni/tests/test_assessor.py @@ -28,9 +28,9 @@ import json from unittest import TestCase, main +_trials = [] +_end_trials = [] -_trials = [ ] -_end_trials = [ ] class NaiveAssessor(Assessor): def assess_trial(self, trial_job_id, trial_history): @@ -47,12 +47,14 @@ def trial_end(self, trial_job_id, success): _in_buf = BytesIO() _out_buf = BytesIO() + def _reverse_io(): _in_buf.seek(0) _out_buf.seek(0) nni.protocol._out_file = _in_buf nni.protocol._in_file = _out_buf + def _restore_io(): _in_buf.seek(0) _out_buf.seek(0) diff --git a/src/sdk/pynni/tests/test_tuner.py b/src/sdk/pynni/tests/test_tuner.py index c1fd3594ee..f2330bd32c 100644 --- a/src/sdk/pynni/tests/test_tuner.py +++ b/src/sdk/pynni/tests/test_tuner.py @@ -32,7 +32,7 @@ class NaiveTuner(Tuner): def __init__(self): self.param = 0 - self.trial_results = [ ] + self.trial_results = [] self.search_space = None self.accept_customized_trials() @@ -57,12 +57,14 @@ def update_search_space(self, search_space): _in_buf = BytesIO() _out_buf = BytesIO() + def _reverse_io(): _in_buf.seek(0) _out_buf.seek(0) nni.protocol._out_file = _in_buf nni.protocol._in_file = _out_buf + def _restore_io(): _in_buf.seek(0) _out_buf.seek(0) @@ -70,7 +72,6 @@ def _restore_io(): nni.protocol._out_file = _out_buf - class TunerTestCase(TestCase): def test_tuner(self): _reverse_io() # now we are sending to Tuner's incoming stream @@ -94,21 +95,20 @@ def test_tuner(self): self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob') _reverse_io() # now we are receiving from Tuner's outgoing stream - self._assert_params(0, 2, [ ], None) - self._assert_params(1, 4, [ ], None) + self._assert_params(0, 2, [], None) + self._assert_params(1, 4, [], None) command, data = receive() # this one is customized data = json.loads(data) self.assertIs(command, CommandType.NewTrialJob) self.assertEqual(data['parameter_id'], 2) self.assertEqual(data['parameter_source'], 'customized') - self.assertEqual(data['parameters'], { 'param': -1 }) + self.assertEqual(data['parameters'], {'param': -1}) - self._assert_params(3, 6, [[1,4,11,False], [2,-1,22,True]], {'name':'SS0'}) + self._assert_params(3, 6, [[1, 4, 11, False], [2, -1, 22, True]], {'name': 'SS0'}) self.assertEqual(len(_out_buf.read()), 0) # no more commands - def _assert_params(self, parameter_id, param, trial_results, search_space): command, data = receive() self.assertIs(command, CommandType.NewTrialJob) From 3274ca3094bf05d4fb9d6afa554a2bd71001b2d8 Mon Sep 17 00:00:00 2001 From: chicm-ms <38930155+chicm-ms@users.noreply.github.com> Date: Wed, 9 Oct 2019 10:08:22 +0800 Subject: [PATCH 2/3] Fix multi phase integration test cases (#1591) * Fix multiphase it cases --- test/config_test/multi_phase/multi_phase_batch.test.yml | 4 ++-- test/config_test/multi_phase/multi_phase_grid.test.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/config_test/multi_phase/multi_phase_batch.test.yml b/test/config_test/multi_phase/multi_phase_batch.test.yml index 089cec9c04..1a488d368a 100644 --- a/test/config_test/multi_phase/multi_phase_batch.test.yml +++ b/test/config_test/multi_phase/multi_phase_batch.test.yml @@ -1,8 +1,8 @@ authorName: nni experimentName: default_test maxExecDuration: 5m -maxTrialNum: 8 -trialConcurrency: 4 +maxTrialNum: 2 +trialConcurrency: 2 searchSpacePath: ./search_space.json tuner: diff --git a/test/config_test/multi_phase/multi_phase_grid.test.yml b/test/config_test/multi_phase/multi_phase_grid.test.yml index 793224e40e..aeb0a0103d 100644 --- a/test/config_test/multi_phase/multi_phase_grid.test.yml +++ b/test/config_test/multi_phase/multi_phase_grid.test.yml @@ -1,8 +1,8 @@ authorName: nni experimentName: default_test maxExecDuration: 5m -maxTrialNum: 8 -trialConcurrency: 4 +maxTrialNum: 2 +trialConcurrency: 2 searchSpacePath: ./search_space.json tuner: From e93d2c25e9301c00bb62c749f815f0258517a218 Mon Sep 17 00:00:00 2001 From: liuzhe-lz <40699903+liuzhe-lz@users.noreply.github.com> Date: Wed, 9 Oct 2019 11:38:26 +0800 Subject: [PATCH 3/3] Merge model compression dev branch to master (#1571) * [Proposal] demo compressor (#1402) model compression * update doc for model compression (#1509) * Update Overview.md * Change Doc (#1510) * refactor compression sdk (#1562) * refactor compression sdk * bugfix * bugfix * update ut * Sync model compression doc and implementation (#1575) * update doc * formatting * bugfix * add import to examples --- azure-pipelines.yml | 10 + docs/en_US/Compressor/AutoCompression.md | 3 + docs/en_US/Compressor/Overview.md | 185 ++++++++++++++++++ docs/en_US/Compressor/Pruner.md | 132 +++++++++++++ docs/en_US/Compressor/Quantizer.md | 78 ++++++++ docs/img/agp_pruner.png | Bin 0 -> 8576 bytes .../model_compress/configure_example.yaml | 9 + examples/model_compress/main_tf_pruner.py | 130 ++++++++++++ examples/model_compress/main_tf_quantizer.py | 117 +++++++++++ examples/model_compress/main_torch_pruner.py | 95 +++++++++ .../model_compress/main_torch_quantizer.py | 87 ++++++++ src/sdk/pynni/nni/compression/__init__.py | 0 .../nni/compression/tensorflow/__init__.py | 3 + .../compression/tensorflow/builtin_pruners.py | 112 +++++++++++ .../tensorflow/builtin_quantizers.py | 74 +++++++ .../nni/compression/tensorflow/compressor.py | 152 ++++++++++++++ .../compression/tensorflow/default_layers.py | 8 + .../pynni/nni/compression/torch/__init__.py | 3 + .../nni/compression/torch/builtin_pruners.py | 131 +++++++++++++ .../compression/torch/builtin_quantizers.py | 76 +++++++ .../pynni/nni/compression/torch/compressor.py | 162 +++++++++++++++ .../nni/compression/torch/default_layers.py | 6 + src/sdk/pynni/tests/test_compressor.py | 116 +++++++++++ 23 files changed, 1689 insertions(+) create mode 100644 docs/en_US/Compressor/AutoCompression.md create mode 100644 docs/en_US/Compressor/Overview.md create mode 100644 docs/en_US/Compressor/Pruner.md create mode 100644 docs/en_US/Compressor/Quantizer.md create mode 100644 docs/img/agp_pruner.png create mode 100644 examples/model_compress/configure_example.yaml create mode 100644 examples/model_compress/main_tf_pruner.py create mode 100644 examples/model_compress/main_tf_quantizer.py create mode 100644 examples/model_compress/main_torch_pruner.py create mode 100644 examples/model_compress/main_torch_quantizer.py create mode 100644 src/sdk/pynni/nni/compression/__init__.py create mode 100644 src/sdk/pynni/nni/compression/tensorflow/__init__.py create mode 100644 src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py create mode 100644 src/sdk/pynni/nni/compression/tensorflow/builtin_quantizers.py create mode 100644 src/sdk/pynni/nni/compression/tensorflow/compressor.py create mode 100644 src/sdk/pynni/nni/compression/tensorflow/default_layers.py create mode 100644 src/sdk/pynni/nni/compression/torch/__init__.py create mode 100644 src/sdk/pynni/nni/compression/torch/builtin_pruners.py create mode 100644 src/sdk/pynni/nni/compression/torch/builtin_quantizers.py create mode 100644 src/sdk/pynni/nni/compression/torch/compressor.py create mode 100644 src/sdk/pynni/nni/compression/torch/default_layers.py create mode 100644 src/sdk/pynni/tests/test_compressor.py diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 1563e4a0ee..a2932fd217 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -10,6 +10,11 @@ jobs: steps: - script: python3 -m pip install --upgrade pip setuptools --user displayName: 'Install python tools' + - script: | + python3 -m pip install torch==0.4.1 --user + python3 -m pip install torchvision==0.2.1 --user + python3 -m pip install tensorflow==1.12.0 --user + displayName: 'Install dependencies for integration' - script: | source install.sh displayName: 'Install nni toolkit via source code' @@ -50,6 +55,11 @@ jobs: steps: - script: python3 -m pip install --upgrade pip setuptools displayName: 'Install python tools' + - script: | + python3 -m pip install torch==0.4.1 --user + python3 -m pip install torchvision==0.2.1 --user + python3 -m pip install tensorflow --user + displayName: 'Install dependencies for integration' - script: | source install.sh displayName: 'Install nni toolkit via source code' diff --git a/docs/en_US/Compressor/AutoCompression.md b/docs/en_US/Compressor/AutoCompression.md new file mode 100644 index 0000000000..fc24f17211 --- /dev/null +++ b/docs/en_US/Compressor/AutoCompression.md @@ -0,0 +1,3 @@ +# Automatic Model Compression on NNI + +TBD. \ No newline at end of file diff --git a/docs/en_US/Compressor/Overview.md b/docs/en_US/Compressor/Overview.md new file mode 100644 index 0000000000..96453caad5 --- /dev/null +++ b/docs/en_US/Compressor/Overview.md @@ -0,0 +1,185 @@ +# Compressor +NNI provides an easy-to-use toolkit to help user design and use compression algorithms. It supports Tensorflow and PyTorch with unified interface. For users to compress their models, they only need to add several lines in their code. There are some popular model compression algorithms built-in in NNI. Users could further use NNI's auto tuning power to find the best compressed model, which is detailed in [Auto Model Compression](./AutoCompression.md). On the other hand, users could easily customize their new compression algorithms using NNI's interface, refer to the tutorial [here](#customize-new-compression-algorithms). + +## Supported algorithms +We have provided two naive compression algorithms and four popular ones for users, including three pruning algorithms and three quantization algorithms: + +|Name|Brief Introduction of Algorithm| +|---|---| +| [Level Pruner](./Pruner.md#level-pruner) | Pruning the specified ratio on each weight based on absolute values of weights | +| [AGP Pruner](./Pruner.md#agp-pruner) | To prune, or not to prune: exploring the efficacy of pruning for model compression. [Reference Paper](https://arxiv.org/abs/1710.01878)| +| [Sensitivity Pruner](./Pruner.md#sensitivity-pruner) | Learning both Weights and Connections for Efficient Neural Networks. [Reference Paper](https://arxiv.org/abs/1506.02626)| +| [Naive Quantizer](./Quantizer.md#naive-quantizer) | Quantize weights to default 8 bits | +| [QAT Quantizer](./Quantizer.md#qat-quantizer) | Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. [Reference Paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf)| +| [DoReFa Quantizer](./Quantizer.md#dorefa-quantizer) | DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. [Reference Paper](https://arxiv.org/abs/1606.06160)| + +## Usage of built-in compression algorithms + +We use a simple example to show how to modify your trial code in order to apply the compression algorithms. Let's say you want to prune all weight to 80% sparsity with Level Pruner, you can add the following three lines into your code before training your model ([here](https://github.com/microsoft/nni/tree/master/examples/model_compress) is complete code). + +Tensorflow code +```python +from nni.compression.tensorflow import LevelPruner +config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }] +pruner = LevelPruner(config_list) +pruner(tf.get_default_graph()) +``` + +PyTorch code +```python +from nni.compression.torch import LevelPruner +config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }] +pruner = LevelPruner(config_list) +pruner(model) +``` + +You can use other compression algorithms in the package of `nni.compression`. The algorithms are implemented in both PyTorch and Tensorflow, under `nni.compression.torch` and `nni.compression.tensorflow` respectively. You can refer to [Pruner](./Pruner.md) and [Quantizer](./Quantizer.md) for detail description of supported algorithms. + +The function call `pruner(model)` receives user defined model (in Tensorflow the model can be obtained with `tf.get_default_graph()`, while in PyTorch the model is the defined model class), and the model is modified with masks inserted. Then when you run the model, the masks take effect. The masks can be adjusted at runtime by the algorithms. + +When instantiate a compression algorithm, there is `config_list` passed in. We describe how to write this config below. + +### User configuration for a compression algorithm + +When compressing a model, users may want to specify the ratio for sparsity, to specify different ratios for different types of operations, to exclude certain types of operations, or to compress only a certain types of operations. For users to express these kinds of requirements, we define a configuration specification. It can be seen as a python `list` object, where each element is a `dict` object. In each `dict`, there are some keys commonly supported by NNI compression: + +* __op_types__: This is to specify what types of operations to be compressed. 'default' means following the algorithm's default setting. +* __op_names__: This is to specify by name what operations to be compressed. If this field is omitted, operations will not be filtered by it. +* __exclude__: Default is False. If this field is True, it means the operations with specified types and names will be excluded from the compression. + +There are also other keys in the `dict`, but they are specific for every compression algorithm. For example, some , some. + +The `dict`s in the `list` are applied one by one, that is, the configurations in latter `dict` will overwrite the configurations in former ones on the operations that are within the scope of both of them. + +A simple example of configuration is shown below: +```python +[ + { + 'sparsity': 0.8, + 'op_types': 'default' + }, + { + 'sparsity': 0.6, + 'op_names': ['op_name1', 'op_name2'] + }, + { + 'exclude': True, + 'op_names': ['op_name3'] + } +] +``` +It means following the algorithm's default setting for compressed operations with sparsity 0.8, but for `op_name1` and `op_name2` use sparsity 0.6, and please do not compress `op_name3`. + +### Other APIs + +Some compression algorithms use epochs to control the progress of compression, and some algorithms need to do something after every minibatch. Therefore, we provide another two APIs for users to invoke. One is `update_epoch`, you can use it as follows: + +Tensorflow code +```python +pruner.update_epoch(epoch, sess) +``` +PyTorch code +```python +pruner.update_epoch(epoch) +``` + +The other is `step`, it can be called with `pruner.step()` after each minibatch. Note that not all algorithms need these two APIs, for those that do not need them, calling them is allowed but has no effect. + +__[TODO]__ The last API is for users to export the compressed model. You will get a compressed model when you finish the training using this API. It also exports another file storing the values of masks. + +## Customize new compression algorithms + +To simplify writing a new compression algorithm, we design programming interfaces which are simple but flexible enough. There are interfaces for pruner and quantizer respectively. + +### Pruning algorithm + +If you want to write a new pruning algorithm, you can write a class that inherits `nni.compression.tensorflow.Pruner` or `nni.compression.torch.Pruner` depending on which framework you use. Then, override the member functions with the logic of your algorithm. + +```python +# This is writing a pruner in tensorflow. +# For writing a pruner in PyTorch, you can simply replace +# nni.compression.tensorflow.Pruner with +# nni.compression.torch.Pruner +class YourPruner(nni.compression.tensorflow.Pruner): + def __init__(self, config_list): + # suggest you to use the NNI defined spec for config + super().__init__(config_list) + + def bind_model(self, model): + # this func can be used to remember the model or its weights + # in member variables, for getting their values during training + pass + + def calc_mask(self, weight, config, **kwargs): + # weight is the target weight tensor + # config is the selected dict object in config_list for this layer + # kwargs contains op, op_type, and op_name + # design your mask and return your mask + return your_mask + + # note for pytorch version, there is no sess in input arguments + def update_epoch(self, epoch_num, sess): + pass + + # note for pytorch version, there is no sess in input arguments + def step(self, sess): + # can do some processing based on the model or weights binded + # in the func bind_model + pass +``` + +For the simpliest algorithm, you only need to override `calc_mask`. It receives each layer's weight and selected configuration, as well as op information. You generate the mask for this weight in this function and return. Then NNI applies the mask for you. + +Some algorithms generate mask based on training progress, i.e., epoch number. We provide `update_epoch` for the pruner to be aware of the training progress. + +Some algorithms may want global information for generating masks, for example, all weights of the model (for statistic information), model optimizer's information. NNI supports this requirement using `bind_model`. `bind_model` receives the complete model, thus, it could record any information (e.g., reference to weights) it cares about. Then `step` can process or update the information according to the algorithm. You can refer to [source code of built-in algorithms](https://github.com/microsoft/nni/tree/master/src/sdk/pynni/nni/compressors) for example implementations. + +### Quantization algorithm + +The interface for customizing quantization algorithm is similar to that of pruning algorithms. The only difference is that `calc_mask` is replaced with `quantize_weight`. `quantize_weight` directly returns the quantized weights rather than mask, because for quantization the quantized weights cannot be obtained by applying mask. + +``` +# This is writing a Quantizer in tensorflow. +# For writing a Quantizer in PyTorch, you can simply replace +# nni.compression.tensorflow.Quantizer with +# nni.compression.torch.Quantizer +class YourPruner(nni.compression.tensorflow.Quantizer): + def __init__(self, config_list): + # suggest you to use the NNI defined spec for config + super().__init__(config_list) + + def bind_model(self, model): + # this func can be used to remember the model or its weights + # in member variables, for getting their values during training + pass + + def quantize_weight(self, weight, config, **kwargs): + # weight is the target weight tensor + # config is the selected dict object in config_list for this layer + # kwargs contains op, op_type, and op_name + # design your quantizer and return new weight + return new_weight + + # note for pytorch version, there is no sess in input arguments + def update_epoch(self, epoch_num, sess): + pass + + # note for pytorch version, there is no sess in input arguments + def step(self, sess): + # can do some processing based on the model or weights binded + # in the func bind_model + pass + + # you can also design your method + def your_method(self, your_input): + #your code + + def bind_model(self, model): + #preprocess model +``` + +__[TODO]__ Will add another member function `quantize_layer_output`, as some quantization algorithms also quantize layers' output. + +### Usage of user customized compression algorithm + +__[TODO]__ ... diff --git a/docs/en_US/Compressor/Pruner.md b/docs/en_US/Compressor/Pruner.md new file mode 100644 index 0000000000..59db5b16c8 --- /dev/null +++ b/docs/en_US/Compressor/Pruner.md @@ -0,0 +1,132 @@ +Pruner on NNI Compressor +=== + +## Level Pruner + +This is one basic pruner: you can set a target sparsity level (expressed as a fraction, 0.6 means we will prune 60%). + +We first sort the weights in the specified layer by their absolute values. And then mask to zero the smallest magnitude weights until the desired sparsity level is reached. + +### Usage + +Tensorflow code +``` +from nni.compression.tensorflow import LevelPruner +config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }] +pruner = LevelPruner(config_list) +pruner(model_graph) +``` + +PyTorch code +``` +from nni.compression.torch import LevelPruner +config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }] +pruner = LevelPruner(config_list) +pruner(model) +``` + +#### User configuration for Level Pruner +* **sparsity:** This is to specify the sparsity operations to be compressed to + +*** + +## AGP Pruner +In [To prune, or not to prune: exploring the efficacy of pruning for model compression](https://arxiv.org/abs/1710.01878), authors Michael Zhu and Suyog Gupta provide an algorithm to prune the weight gradually. + +>We introduce a new automated gradual pruning algorithm in which the sparsity is increased from an initial sparsity value si (usually 0) to a final sparsity value sf over a span of n pruning steps, starting at training step t0 and with pruning frequency ∆t: +![](../../img/agp_pruner.png) +>The binary weight masks are updated every ∆t steps as the network is trained to gradually increase the sparsity of the network while allowing the network training steps to recover from any pruning-induced loss in accuracy. In our experience, varying the pruning frequency ∆t between 100 and 1000 training steps had a negligible impact on the final model quality. Once the model achieves the target sparsity sf , the weight masks are no longer updated. The intuition behind this sparsity function in equation + +### Usage +You can prune all weight from %0 to 80% sparsity in 10 epoch with the code below. + +First, you should import pruner and add mask to model. + +Tensorflow code +```python +from nni.compression.tensorflow import AGP_Pruner +config_list = [{ + 'initial_sparsity': 0, + 'final_sparsity': 0.8, + 'start_epoch': 1, + 'end_epoch': 10, + 'frequency': 1, + 'op_types': 'default' +}] +pruner = AGP_Pruner(config_list) +pruner(tf.get_default_graph()) +``` +PyTorch code +```python +from nni.compression.torch import AGP_Pruner +config_list = [{ + 'initial_sparsity': 0, + 'final_sparsity': 0.8, + 'start_epoch': 1, + 'end_epoch': 10, + 'frequency': 1, + 'op_types': 'default' +}] +pruner = AGP_Pruner(config_list) +pruner(model) +``` + +Second, you should add code below to update epoch number when you finish one epoch in your training code. + +Tensorflow code +```python +pruner.update_epoch(epoch, sess) +``` +PyTorch code +```python +pruner.update_epoch(epoch) +``` +You can view example for more information + +#### User configuration for AGP Pruner +* **initial_sparsity:** This is to specify the sparsity when compressor starts to compress +* **final_sparsity:** This is to specify the sparsity when compressor finishes to compress +* **start_epoch:** This is to specify the epoch number when compressor starts to compress +* **end_epoch:** This is to specify the epoch number when compressor finishes to compress +* **frequency:** This is to specify every *frequency* number epochs compressor compress once + +*** + +## Sensitivity Pruner +In [Learning both Weights and Connections for Efficient Neural Networks](https://arxiv.org/abs/1506.02626), author Song Han and provide an algorithm to find the sensitivity of each layer and set the pruning threshold to each layer. + +>We used the sensitivity results to find each layer’s threshold: for example, the smallest threshold was applied to the most sensitive layer, which is the first convolutional layer... The pruning threshold is chosen as a quality parameter multiplied by the standard deviation of a layer’s weights + +### Usage +You can prune weight step by step and reach one target sparsity by Sensitivity Pruner with the code below. + +Tensorflow code +```python +from nni.compression.tensorflow import SensitivityPruner +config_list = [{ 'sparsity':0.8, 'op_types': 'default' }] +pruner = SensitivityPruner(config_list) +pruner(tf.get_default_graph()) +``` +PyTorch code +```python +from nni.compression.torch import SensitivityPruner +config_list = [{ 'sparsity':0.8, 'op_types': 'default' }] +pruner = SensitivityPruner(config_list) +pruner(model) +``` +Like AGP Pruner, you should update mask information every epoch by adding code below + +Tensorflow code +```python +pruner.update_epoch(epoch, sess) +``` +PyTorch code +```python +pruner.update_epoch(epoch) +``` +You can view example for more information + +#### User configuration for Sensitivity Pruner +* **sparsity:** This is to specify the sparsity operations to be compressed to + +*** diff --git a/docs/en_US/Compressor/Quantizer.md b/docs/en_US/Compressor/Quantizer.md new file mode 100644 index 0000000000..be91dcc339 --- /dev/null +++ b/docs/en_US/Compressor/Quantizer.md @@ -0,0 +1,78 @@ +Quantizer on NNI Compressor +=== + +## Naive Quantizer + +We provide Naive Quantizer to quantizer weight to default 8 bits, you can use it to test quantize algorithm without any configure. + +### Usage +tensorflow +```python +nni.compressors.tensorflow.NaiveQuantizer()(model_graph) +``` +pytorch +```python +nni.compressors.torch.NaiveQuantizer()(model) +``` + +*** + +## QAT Quantizer +In [Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference](http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf), authors Benoit Jacob and Skirmantas Kligys provide an algorithm to quantize the model with training. + +>We propose an approach that simulates quantization effects in the forward pass of training. Backpropagation still happens as usual, and all weights and biases are stored in floating point so that they can be easily nudged by small amounts. The forward propagation pass however simulates quantized inference as it will happen in the inference engine, by implementing in floating-point arithmetic the rounding behavior of the quantization scheme +>* Weights are quantized before they are convolved with the input. If batch normalization (see [17]) is used for the layer, the batch normalization parameters are “folded into” the weights before quantization. +>* Activations are quantized at points where they would be during inference, e.g. after the activation function is applied to a convolutional or fully connected layer’s output, or after a bypass connection adds or concatenates the outputs of several layers together such as in ResNets. + + +### Usage +You can quantize your model to 8 bits with the code below before your training code. + +Tensorflow code +```python +from nni.compressors.tensorflow import QAT_Quantizer +config_list = [{ 'q_bits': 8, 'op_types': 'default' }] +quantizer = QAT_Quantizer(config_list) +quantizer(tf.get_default_graph()) +``` +PyTorch code +```python +from nni.compressors.torch import QAT_Quantizer +config_list = [{ 'q_bits': 8, 'op_types': 'default' }] +quantizer = QAT_Quantizer(config_list) +quantizer(model) +``` + +You can view example for more information + +#### User configuration for QAT Quantizer +* **q_bits:** This is to specify the q_bits operations to be quantized to + + +*** + +## DoReFa Quantizer +In [DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients](https://arxiv.org/abs/1606.06160), authors Shuchang Zhou and Yuxin Wu provide an algorithm named DoReFa to quantize the weight, activation and gradients with training. + +### Usage +To implement DoReFa Quantizer, you can add code below before your training code + +Tensorflow code +```python +from nni.compressors.tensorflow import DoReFaQuantizer +config_list = [{ 'q_bits': 8, 'op_types': 'default' }] +quantizer = DoReFaQuantizer(config_list) +quantizer(tf.get_default_graph()) +``` +PyTorch code +```python +from nni.compressors.torch import DoReFaQuantizer +config_list = [{ 'q_bits': 8, 'op_types': 'default' }] +quantizer = DoReFaQuantizer(config_list) +quantizer(model) +``` + +You can view example for more information + +#### User configuration for QAT Quantizer +* **q_bits:** This is to specify the q_bits operations to be quantized to diff --git a/docs/img/agp_pruner.png b/docs/img/agp_pruner.png new file mode 100644 index 0000000000000000000000000000000000000000..889f42e7647e705f773b58141b2ef6c3067d3e95 GIT binary patch literal 8576 zcmd6tRa6{7x2Pe(onV7|kl^kFC%C%}9w0D4@WEvugF}EMI3xsz;LZ#|0s%sR!EFc< z++FVEKWD9bU+(K!=b^iIt*)-#yK2`i`6@wAM~x7V1`iDljqsJavH=>}lW3G|f`f&+ zH$KUgMm^Ad4b&9TYR2gPq7oP`3fc;2X!R-hcR);38W*H)?u&*-)c5Z~AN2a_h=#^Y z@k&|2D9~nq30Kc>YZ(i~#qzA0V#=aau^Pc=TaBRT|GesH&4-L)YJOMEwA0l@(ZpRG zHT(sWbc%bLyO6VWO~{_hOu>?xWbTX`YNg^hi{+@k`RC^;JIMtb8#8X?UAVA|D$&lX zs+^6?b+M*AM~L!v7MO8fUk0_u%; zw>#XqYDs-5;h7OP6%|jQNK0A)w-~Y-#0A!C`dhqu^bXM7q0a`N31V=k`3G`?`Ni8qhiG9H@;O{);B@ckG}_CPATcC%E<8s%r-p`!Tgj%Z+24h-S)IU<$F^&bA**vNAj+Ez$gvVk zR$b_0Hab}A?+e@}XhsK;{9sJU>~{r~EzVjW7OHP{Il!JW#Dk}fcRdw&srdk?{Gn5E z>PO#ait+O(oAMZM@rs~HKzpi4`=qelAr*xiS+>QlXq;@|FUCI+WXHJbT4NOVHpWJc z))%5<^B;TnK0#@04q`}(mpBR1by99GO~4br#F3LLk#hkv3Vm28ZJklxCnef^B+etS7ady$mHMeCv|H5FrHxw$}87g~pO6g_wBdbYX3<^)Z-d+k}(CI7= z=tr56@Xw6Ilw4b$i9V1Zd|0_sgdsYUz7KfxnU08{N)=7Z6VhU!%~hL1jy#a;U}*kF$rXV|VlWVqmMzlE${gd$gcn_UYJYkJ zbqYB981Up2~(V&7yk`gfWt&IwE&SSG&_svGoV0l_aV#J$u{}A?xBYd5T`v&e;i+oh# zJ>vZl?@cEJICXK8M-L*?qzzT%TjEmadS3lHlq6>EFE&Q?HxaB@7*W}iHc6FYd3W?? zjJCGVc(HBo@cp+dFS)D^PnF$87#}}~^GG{JrYUtKkcc|ypKTL= zXLBs9qaTyUi>Al5jQ?~nHcMs8ZBazX4I1AH@#sEN9D6yWgw$nJ8h520>mjY)!O1bY z(xJy)h~SLD-h;zv-P~L_ygR1CA8B7S;(tghO4Rb?rj%pqZr_XbnunO=}MHW zk_h3UIfmHSx8B@y^!UzqI#G?J89_7)_x2N6);8xXn1VeW z4#%w%W`BICKD)_)ISKp$GNp{)`cGENgtQK)e+cM$gZo+bbkn!I|B5D$3Pmh$%MPWM z*(Zf2Nn^LHneBFo8$q}f{8foGjBS}`GiUdO)aNfS3Esf<3d*k~kMgGQCy9E;$$X<{z~M8L!1J95Qw(C3Hko3Se2YboUx((!6Z51?VPS0-$N*uR#JOI`32_ z%q=_1DpT}37~EFQEokdiB;0Qq-kz%gv!5N%Yi)*UZ42V;+E35;rEYADv^ zr4z$Pms8&Ic6~eYjRKxFwd}1F(Bfw+SWkq!R#u8ni&xl;?V3I53zp!gPU^P4bztur zs~Xf?i~*MT7l!>j5TB_F{z^e4_nxEL1exfk`awDKeuR{R=}F`*p675MH#KVJr?r@<{*dy{Tl4L~_KNz|e3VUZaHVXDc!`$ zD)ZW4(YBiaaZl5(fIPv&+>K;Ap-&_YW4n`Hq_1X@T`Q;7!@Q&oY5pBOgw~XA4cQNk z+IYsdqOKZQl})1$Xp}Nyx*jm}CelR3GR`X5U7`g&RiUr%K)rMYE^i1rc;2z()}6SD z359Z$K$7KDNA#r4ygj)R38;#!p{Bf2aJft}%NU%p_YxoTa94iF7M93c#LdG@h841W z4SsMERA^HuW?R4;zC?C|XNy;W!fz${)ER|aCo8)j=;iV$Vzz{G>9uY=6Myf<)e;TQ zWaPjL*?5Le(@7XZs9y7LKqH&30yp2F%XZZt{WRG9-2Oo$w|TR0x6tp(FavI&uOMdK?{2m;x93H&aqpbd@t4)Gg$;HhxF=Ms#Pw!14 zxJwOHZ8||iR&Nib0JO;Yt!+n!yr{}kZnKYm{g#IWkn>9&kVVJSt^BFHj1BdaGdQu% z7|5cD?uB^kdf*-O`If;sR5>Q7-EMq(W~l8TRuq9>cX2t0w*qd({M(uhq6%QK63_}K zq0{-!$}C9TR!y)rM7c6xqKRLH)1Fe?ynM~J9N5aK-UEs02q0M#BDjgB1F^`VC;VtD zUa5o>JF(R^W8(s=3x!P>M#`co&r0V%4&)h%2wLKiw6IRk%T#0;~dEHEA>~B=Pozu^Y%Fo>5Pq=eMpi{D{9R= zvxvlxD+Z=Qdcv+S3hpzR1XquJ;j4}fM9i1@B~?fY1065;K&ngeN;3z0cI|5U!kbbq zy4n_-&xREBgh_=@^CoJNdFiub6AEc~MnZj)QBjFL~@ zX&B)Xi(RnuNfGv?20>CSe2>GU)gV;8%8#;Ir$J0t3H`eFS&Ki0OS35^|@W}c|$N^{pyRcuv`RJyPN;m{~Dp8H)x3e=6J@ztL1))W~e(H(s z7q22mJ!L$YpY=tN09J{1VZj3VYj}@JkvMqTB=?wpdUc89zcvLpTjhf06WlO&Gp%72 zk??tiTgs+rkD^r=3nvNs0Rm9vl1N?iAVCg@`4uyU`0%RXk~za0ARqjY0C^A=1bH&q zD+V~Iy(E7a<9azj3hVyRs*Ci{`>CWm`(K5c$Ns0#&8BFd%%_*cRj6t5&ZnkVe!w~z zJR+s4!(+-0Nt<^>Jbh_{b|m)2R@6sUtvpGe7W>4=>~XyvQ3RqH@b zqa9Dn9DWl~TiMwOFRZp9q^a|D!tv%=k6ocV-WK3`8Pfn@l(*xm2H3rb&^D`d`Lcua zMCaFA1%Qa3;<{1^n;WK9JK_LxBRC7eFKeK@|^9zWUMgdY|Zqy!hVWk~K%!e*;Ajhqrb zm_;5C+NKoL_ut|<=6|#BOdnA{4r=)1W(jcA2QO2cLJC&yTbccqb+FNGC$hC$+w}4qJ+{Ju|+)$1;>nBIxoJQpRWzS=5bhHx0GCd~L84IfoN+Ce6>H+^-hJlr}&eP7B8VBk>l6ShT#Ho-B0 z4Krq@p5rpIm;z{}fTs$ETUpg6tGKJ&bB}h`K7R_PqD1nK;-9#b9@evF+Q@rqf!)W! zP@mDC!vnhrKI|xnx$f<#wmYgr+6)X|`V~p{GEiBZg|D*=%#qVoqKR-VkE#Fm=EE_j zGEM4x7hgPoerJU-S`eZA+rRkRMUmaYOkYiL>ISm8fm$Nt{^Qv}C|Mib{c`;PwP1PD ztMi!q%`Z(+ej5C>-*4O70e0a<8+sN(`G*^qp@N5xs%xJL1OJ$CUc|EyPfrsFnP2&> zVQ<$=Vp_jFvlCflPPbWe{?t!TQT-mME<^GcAzA6F-gYyvpuE#qK! zIGHt99e*Rwz@m2e172|6-M2g3PEmdp+Z&Eoi`!MVnps+*4!a1mjwf#4E!3o<-=-#n zgzBX$)>d;P!E3HwGkd2AwtFTXN`FpGi1a>Zdr^!nT=XC9zp9RdAxngf8L?g0t$J{8_MEkM54WQ;E z%QZp3yittAkJn`IN%p)T0t0^n2rS!E+xC4uCQkkwwy&;;OzBC34A2vaW?oj61-TM3 z@L2YEzAGXo)z~@4n!@;?r%~c}vYUo4+lAx6N{WjTeR7v^V}ybhLQQx*0l!Uoht}FcwOxp5Z_4Jd#QsLP|AO zFc@T>g`d`%#j|@QgOqoWRVYk2^XU^W3F&j(k!lT!VbU)G!YIJnKfB>;4bFv z6JjyFHa*XVtdi(*AV6SZKQgxHcf}hcEh_+dm9$WU^xN-eZ=D(197^*U9=vwtMde?} z6MP#Yu6r{0le)!8#;8p^SHN4TD&2%7X4#rA%sd2IotP* z{f~pC2k*^Zn3bI&#iMrfRuz1%W8)wcAI2}m;~!rGZggg&*+iF`jY+VshvXV)DAy54 z`;zE~Q5Ec*#JQxib^~5{DEmDJOf4&o<9git1&DYt;NZKCIpwAGQF|)iXSM3Yg8JmT zO}k>jJ=O&uGQ-AEJ5j!HVKfK-vX(%fKQ09@H>kJS!kK0-%#px<28@f9Lgbbz1Zx{oQbj^CrHMS;6bBog!{9=`tcz_n-~Lg_=q>**#vq9w%Z`w6`TYO{<@w zyOE;Pl3zTV);^y$7C_J@o)>$&W}-SIDR80nUJtkUJ}QlE|GC*YHbS>it@Rttif=8< zI|ALAo=iB-#R-JsJ5fW`Wxj?6|`^X6!RwIB8q(?3@5AoRGRz8V34+pe8#kEFy|=dR>|YbWu;dW z0%nxWIB{KX+^~vk4e8etiG#IZ=nASpeAd-)p|T4q4?#ZlN>m>yb*K$GomZe94_LPL zpmSsTTpMvx&O~6cw)l-$5WePGGB3x#aPaY+d5_4{Ht5j|%G}qy>g8lB?#Qy>Qf11H zWA>D%&bEe{>Z73u!AaV>IOg*gd`*FP)7i-PvSgQ``>3^@&>h)8zRpY4rJZaI(jHx*ZTHOBQ;=94>~B((dLE-v{KL-?$H5oCgve<22S?lJX?q1VE<#p!=XiQL)hn5zlvoDTjuadh1~I>1Qi<3VyDDgln{Bm}2zR7H zbnNwyKDUihc}Oz{2nYkN!B~kO@hm7dUU+1PJY*2YXRkfcUqDoTI^$~I!Q~q!m0sOz zq#1cm9Jwalb^HKum#FcE_oN`Y8TQEM{Z7D?LD*jkKP|iJS2x`p$2wpHOWT4dj|UC|5;%5Y3VCL>Hq^R<_dM)1c@w)a*FM zfhQkMrPUF}D+c!R(RDJ@>_okSH}ew!$q+|%{o}{|r8bkG)(NUr_IBV|$30MN$?(85 zjF9bpTKeWFr9PXwk=29A5$G4Db^B2LHlP>&qH=RZxV$?@>Nh32{u<7#ABi~4A=rjx zQMCv9;5@lDmt}g`c{?lo8ZEJG&nw5kfMoaJ@IwsxfMH$F|MdmhZh=OHzM}7H$%kxb zAN|b+e2XF#Q3E<4WORiRqp|V1Wa%9OE3s}d<`)TZYo0o#@qp3lNz3zJ6fL%&d8Du? zk1AeAv8-P&mk&@o4j;1!Qxh}DsQ{rZ;VKehWsH=$~Zd^&B6stsaUJIv?8VqaxC z0QWrMuh+Ma{;k6?=a#@S{?KYt2fYUtvLD0@$nsWx*ybA53o;&}7pi|Zx{eIDW`#6( zZ58ZJ8@>;XUOEL!IpIeOSlbfYdgyxkM7h^q8^bAInE$V}v2!&_V-bfjL>AE^9r2jc1=aB%36GQJ*@lu; z9=ybI&L(Y>Y#>ANp$U`Cn`X^}#Gc(V`h_*Dn$AL5{QS7m-5*m{>8N!zJqdEL{A1HX zGw+-b)La5V-F0uY&Vuo0>GV}Q-g7I@UZ|Q%N@UqV`y%0)l65mOY(0`=md=}|1PIu_@o2XQQP1($+|MUO4r zVMkw=yd2x?Tlq$N_WVW>aW@;pB^D&d zJq9dKW2WSCxB!P?T_jE@T~fn1r#heNci_4^%L(>g-Ru4#zW0zkNk@`~h!Y~qs+?2( z-a?t!n)mfB+MCBdRk!6-mCUl>v{dkPA<7*X@rXamu=NijJ?AwE$1&aXC}KB4C9iG* zOBiyv5y?QkKI4^$$yEJWnh?qtyw={PPu+yC+KNEt8!3>yq6*FAlMQDKhxhBXfjl!W z@c14;*J#1_J+ZmMzy0L!wDb8nVgJ3sMcuGvssis$u^ElO*$;7GepQJlRj>Qfwlf~Ee8?3pT*x5NuMuzC!sSbPgn0YEJHpn=bp9yh6E| zU0!Dm;bQ8F)KFm(ODhBr>H;t@5&(%XE7QoGcrw7p3J~jRhw3s0zG|pl`O-=Ig4@1t zA;dqgjS~zZP54;UknotT0^&7!U#3m00j9T?L>S8VNx4x3boC603Ab#W7w>&fz-toV z{}1+IG3%6(AzTXi3D;-W)zZabF5O4bmN95SLf85k&@n_O?lp!e;LvOD`dGr|WCY}< zpu0_M9@%j2@KGSJ(U(oaFlHT?U@~n8J^^P!9K6?IaSNpwWmFe&56T5!8pi)2Rh)ku zX<%^uejR9sjCUOkn&cO6zLJ91P)1e%1rER*dj`5fR+xN(3yF+s=(Luu?oM-j|*<>o(%$4tag6+MT_P@JOLBnC|^$= zD=hD$Y1pE3Q}~-48QY@_&{=3S|I)lfjal_K{EO`scPXM5C9HTT-uSr$#jC`dNqYGA zAOYDiGdkE)KwJ)q7L2kVx$RLrfEIKm39_{%@KBMMRzCmUBAH!(uYz>_4n0^)GQvX+ z3daf0t)e!nf0MCQBJ`v=X#qb=xsP9>{9@b(DSrVgZGrS|q!8WZqY$c;+wAkU{pryX*ii;JQm9G#wUkso{1p=LSXU_v=&c~7!PBw7os8% zj$u1Ki}M0cvOX!2BG=-`el#7%EJy`)3uxW^zoh3x+t38^;{=x3U9x(Sf|WR{gQDmx zc+tw5Ml26y{B~}k&xW^|D*Zhq^S3%Fg+m5uJJofSY}(S5`2I5(sDIM5<__Zv%X5h$ z{q(rg+EkAsee5Xmy03qZ*1;PG-!v*Qtyado8p5p~P+Jc`%CVz({6yeCpbBJQge{`r zH2-@W>k17N^D+zVCgI5E23paekUh5=Fx8cslHFU{PI8Pi?Ru(GVeaaBt9}?PC3EO^ z=S`17yp|}iYN1Yo^LX<88H8zKT$w9}fX?3x>+pLV%nL>RUFGniQ4|h`*&-3e%TZE1 wC92B)9S11xN9p+QV1TRO|5O&<)kpLZM<2qs&_mQ|2{g1= final_sparsity: + _logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity') + return final_sparsity + + now_epoch = tf.minimum(self.now_epoch, tf.constant(end_epoch)) + span = int(((end_epoch - start_epoch-1)//freq)*freq) + assert span > 0 + base = tf.cast(now_epoch - start_epoch, tf.float32) / span + target_sparsity = (final_sparsity + + (initial_sparsity - final_sparsity)* + (tf.pow(1.0 - base, 3))) + return target_sparsity + + def update_epoch(self, epoch, sess): + sess.run(self.assign_handler) + sess.run(tf.assign(self.now_epoch, int(epoch))) + + +class SensitivityPruner(Pruner): + """ + Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks" + https://arxiv.org/pdf/1506.02626v3.pdf + + I.e.: "The pruning threshold is chosen as a quality parameter multiplied + by the standard deviation of a layers weights." + """ + def __init__(self, config_list): + """ + Configure Args: + sparsity: chosen pruning sparsity + """ + super().__init__(config_list) + self.layer_mask = {} + self.assign_handler = [] + + def calc_mask(self, weight, config, op_name, **kwargs): + target_sparsity = config['sparsity'] * tf.math.reduce_std(weight) + mask = tf.get_variable(op_name + '_mask', initializer=tf.ones(weight.shape), trainable=False) + self.layer_mask[op_name] = mask + + weight_assign_handler = tf.assign(weight, mask*weight) + # use control_dependencies so that weight_assign_handler will be executed before mask_update_handler + with tf.control_dependencies([weight_assign_handler]): + threshold = tf.contrib.distributions.percentile(weight, target_sparsity * 100) + # stop gradient in case gradient change the mask + new_mask = tf.stop_gradient(tf.cast(tf.math.greater(weight, threshold), weight.dtype)) + mask_update_handler = tf.assign(mask, new_mask) + self.assign_handler.append(mask_update_handler) + return mask + + def update_epoch(self, epoch, sess): + sess.run(self.assign_handler) diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_quantizers.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_quantizers.py new file mode 100644 index 0000000000..3dde1f2f1c --- /dev/null +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_quantizers.py @@ -0,0 +1,74 @@ +import logging +import tensorflow as tf +from .compressor import Quantizer + +__all__ = [ 'NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer' ] + +_logger = logging.getLogger(__name__) + + +class NaiveQuantizer(Quantizer): + """ + quantize weight to 8 bits + """ + def __init__(self, config_list): + super().__init__(config_list) + self.layer_scale = { } + + def quantize_weight(self, weight, config, op_name, **kwargs): + new_scale = tf.reduce_max(tf.abs(weight)) / 127 + scale = tf.maximum(self.layer_scale.get(op_name, tf.constant(0.0)), new_scale) + self.layer_scale[op_name] = scale + orig_type = weight.dtype + return tf.cast(tf.cast(weight / scale, tf.int8), orig_type) * scale + + +class QAT_Quantizer(Quantizer): + """ + Quantizer using the DoReFa scheme, as defined in: + Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference + http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf + """ + def __init__(self, config_list): + """ + Configure Args: + q_bits + """ + super().__init__(config_list) + + def quantize_weight(self, weight, config, **kwargs): + a = tf.stop_gradient(tf.reduce_min(weight)) + b = tf.stop_gradient(tf.reduce_max(weight)) + n = tf.cast(2 ** config['q_bits'], tf.float32) + scale = b-a/(n-1) + + # use gradient_override_map to change round to idetity for gradient + with tf.get_default_graph().gradient_override_map({'Round': 'Identity'}): + qw = tf.round((weight-a)/scale)*scale +a + + return qw + + +class DoReFaQuantizer(Quantizer): + """ + Quantizer using the DoReFa scheme, as defined in: + Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients + (https://arxiv.org/abs/1606.06160) + """ + def __init__(self, config_list): + """ + Configure Args: + q_bits + """ + super().__init__(config_list) + + def quantize_weight(self, weight, config, **kwargs): + a = tf.math.tanh(weight) + b = a/(2*tf.reduce_max(tf.abs(weight))) + 0.5 + + scale = pow(2, config['q_bits'] - 1) + # use gradient_override_map to change round to idetity for gradient + with tf.get_default_graph().gradient_override_map({'Round': 'Identity'}): + qw = tf.round(b*scale)/scale + r_qw = 2 * qw - 1 + return r_qw diff --git a/src/sdk/pynni/nni/compression/tensorflow/compressor.py b/src/sdk/pynni/nni/compression/tensorflow/compressor.py new file mode 100644 index 0000000000..3e8b638054 --- /dev/null +++ b/src/sdk/pynni/nni/compression/tensorflow/compressor.py @@ -0,0 +1,152 @@ +import tensorflow as tf +import logging +from . import default_layers + +_logger = logging.getLogger(__name__) + + +class LayerInfo: + def __init__(self, op): + self.op = op + self.name = op.name + self.type = op.type + + +class Compressor: + """ + Abstract base TensorFlow compressor + """ + def __init__(self, config_list): + self._bound_model = None + self._config_list = config_list + + def __call__(self, model): + self.compress(model) + return model + + def compress(self, model): + """ + Compress given graph with algorithm implemented by subclass. + This will edit the graph. + """ + assert self._bound_model is None, "Each NNI compressor instance can only compress one model" + self._bound_model = model + self.bind_model(model) + for op in model.get_operations(): + layer = LayerInfo(op) + config = self._select_config(layer) + if config is not None: + self._instrument_layer(layer, config) + + def compress_default_graph(self): + """ + Compress the default graph with algorithm implemented by subclass. + This will edit the graph. + """ + self.compress(tf.get_default_graph()) + + + def bind_model(self, model): + """ + This method is called when a model is bound to the compressor. + Users can optionally overload this method to do model-specific initialization. + It is guaranteed that only one model will be bound to each compressor instance. + """ + pass + + def update_epoch(self, epoch, sess): + """ + if user want to update mask every epoch, user can override this method + """ + pass + + def step(self, sess): + """ + if user want to update mask every step, user can override this method + """ + pass + + + def _instrument_layer(self, layer, config): + raise NotImplementedError() + + def _select_config(self, layer): + ret = None + for config in self._config_list: + op_types = config.get('op_types') + if op_types == 'default': + op_types = default_layers.op_weight_index.keys() + if op_types and layer.type not in op_types: + continue + if config.get('op_names') and layer.name not in config['op_names']: + continue + ret = config + if ret is None or ret.get('exclude'): + return None + return ret + + +class Pruner(Compressor): + """ + Abstract base TensorFlow pruner + """ + def __init__(self, config_list): + super().__init__(config_list) + + def calc_mask(self, weight, config, op, op_type, op_name): + """ + Pruners should overload this method to provide mask for weight tensors. + The mask must have the same shape and type comparing to the weight. + It will be applied with `multiply()` operation. + This method works as a subgraph which will be inserted into the bound model. + """ + raise NotImplementedError("Pruners must overload calc_mask()") + + def _instrument_layer(self, layer, config): + """ + it seems the graph editor can only swap edges of nodes or remove all edges from a node + it cannot remove one edge from a node, nor can it assign a new edge to a node + we assume there is a proxy operation between the weight and the Conv2D layer + this is true as long as the weight is `tf.Value` + not sure what will happen if the weight is calculated from other operations + """ + weight_index = _detect_weight_index(layer) + if weight_index is None: + _logger.warning('Failed to detect weight for layer {}'.format(layer.name)) + return + weight_op = layer.op.inputs[weight_index].op + weight = weight_op.inputs[0] + mask = self.calc_mask(weight, config, op=layer.op, op_type=layer.type, op_name=layer.name) + new_weight = weight * mask + tf.contrib.graph_editor.swap_outputs(weight_op, new_weight.op) + + +class Quantizer(Compressor): + """ + Abstract base TensorFlow quantizer + """ + def __init__(self, config_list): + super().__init__(config_list) + + def quantize_weight(self, weight, config, op, op_type, op_name): + raise NotImplementedError("Quantizer must overload quantize_weight()") + + def _instrument_layer(self, layer, config): + weight_index = _detect_weight_index(layer) + if weight_index is None: + _logger.warning('Failed to detect weight for layer {}'.format(layer.name)) + return + weight_op = layer.op.inputs[weight_index].op + weight = weight_op.inputs[0] + new_weight = self.quantize_weight(weight, config, op=layer.op, op_type=layer.type, op_name=layer.name) + tf.contrib.graph_editor.swap_outputs(weight_op, new_weight.op) + + +def _detect_weight_index(layer): + index = default_layers.op_weight_index.get(layer.type) + if index is not None: + return index + weight_indices = [ i for i, op in enumerate(layer.op.inputs) if op.name.endswith('Variable/read') ] + if len(weight_indices) == 1: + return weight_indices[0] + return None diff --git a/src/sdk/pynni/nni/compression/tensorflow/default_layers.py b/src/sdk/pynni/nni/compression/tensorflow/default_layers.py new file mode 100644 index 0000000000..0f44ca2987 --- /dev/null +++ b/src/sdk/pynni/nni/compression/tensorflow/default_layers.py @@ -0,0 +1,8 @@ +op_weight_index = { + 'Conv2D': None, + 'Conv3D': None, + 'DepthwiseConv2dNative': None, + + 'Mul': None, + 'MatMul': None, +} diff --git a/src/sdk/pynni/nni/compression/torch/__init__.py b/src/sdk/pynni/nni/compression/torch/__init__.py new file mode 100644 index 0000000000..baf2f84628 --- /dev/null +++ b/src/sdk/pynni/nni/compression/torch/__init__.py @@ -0,0 +1,3 @@ +from .compressor import LayerInfo, Compressor, Pruner, Quantizer +from .builtin_pruners import * +from .builtin_quantizers import * diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py new file mode 100644 index 0000000000..7309ce1eb3 --- /dev/null +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ -0,0 +1,131 @@ +import logging +import torch +from .compressor import Pruner + +__all__ = [ 'LevelPruner', 'AGP_Pruner', 'SensitivityPruner' ] + +logger = logging.getLogger('torch pruner') + + +class LevelPruner(Pruner): + """Prune to an exact pruning level specification + """ + def __init__(self, config_list): + """ + we suggest user to use json configure list, like [{},{}...], to set configure + format : + [ + { + 'sparsity': 0, + 'support_type': 'default' + }, + { + 'sparsity': 50, + 'support_op': conv1 + } + ] + if you want input multiple configure from file, you'd better use load_configure_file(path) to load + """ + super().__init__(config_list) + + def calc_mask(self, weight, config, **kwargs): + w_abs = weight.abs() + k = int(weight.numel() * config['sparsity']) + if k == 0: + return torch.ones(weight.shape) + threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max() + return torch.gt(w_abs, threshold).type(weight.type()) + + +class AGP_Pruner(Pruner): + """ + An automated gradual pruning algorithm that prunes the smallest magnitude + weights to achieve a preset level of network sparsity. + + Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the + efficacy of pruning for model compression", 2017 NIPS Workshop on Machine + Learning of Phones and other Consumer Devices, + https://arxiv.org/pdf/1710.01878.pdf + """ + def __init__(self, config_list): + """ + Configure Args + initial_sparsity + final_sparsity: you should make sure initial_sparsity <= final_sparsity + start_epoch: start epoch numer begin update mask + end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch + frequency: if you want update every 2 epoch, you can set it 2 + """ + super().__init__(config_list) + self.mask_list = {} + self.now_epoch = 1 + + def calc_mask(self, weight, config, op_name, **kwargs): + mask = self.mask_list.get(op_name, torch.ones(weight.shape)) + target_sparsity = self.compute_target_sparsity(config) + k = int(weight.numel() * target_sparsity) + if k == 0 or target_sparsity >= 1 or target_sparsity <= 0: + return mask + # if we want to generate new mask, we should update weigth first + w_abs = weight.abs()*mask + threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max() + new_mask = torch.gt(w_abs, threshold).type(weight.type()) + self.mask_list[op_name] = new_mask + return new_mask + + def compute_target_sparsity(self, config): + end_epoch = config.get('end_epoch', 1) + start_epoch = config.get('start_epoch', 1) + freq = config.get('frequency', 1) + final_sparsity = config.get('final_sparsity', 0) + initial_sparsity = config.get('initial_sparsity', 0) + if end_epoch <= start_epoch or initial_sparsity >= final_sparsity: + logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity') + return final_sparsity + + if end_epoch <= self.now_epoch: + return final_sparsity + + span = ((end_epoch - start_epoch-1)//freq)*freq + assert span > 0 + target_sparsity = (final_sparsity + + (initial_sparsity - final_sparsity)* + (1.0 - ((self.now_epoch - start_epoch)/span))**3) + return target_sparsity + + def update_epoch(self, epoch): + if epoch > 0: + self.now_epoch = epoch + + +class SensitivityPruner(Pruner): + """ + Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks" + https://arxiv.org/pdf/1506.02626v3.pdf + + I.e.: "The pruning threshold is chosen as a quality parameter multiplied + by the standard deviation of a layers weights." + """ + def __init__(self, config_list): + """ + configure Args: + sparsity: chosen pruning sparsity + """ + super().__init__(config_list) + self.mask_list = {} + + + def calc_mask(self, weight, config, op_name, **kwargs): + mask = self.mask_list.get(op_name, torch.ones(weight.shape)) + # if we want to generate new mask, we should update weigth first + weight = weight*mask + target_sparsity = config['sparsity'] * torch.std(weight).item() + k = int(weight.numel() * target_sparsity) + if k == 0: + return mask + + w_abs = weight.abs() + threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max() + new_mask = torch.gt(w_abs, threshold).type(weight.type()) + self.mask_list[op_name] = new_mask + return new_mask diff --git a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py new file mode 100644 index 0000000000..9f2b9ccd95 --- /dev/null +++ b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py @@ -0,0 +1,76 @@ +import logging +import torch +from .compressor import Quantizer + +__all__ = [ 'NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer' ] + +logger = logging.getLogger(__name__) + + +class NaiveQuantizer(Quantizer): + """ + quantize weight to 8 bits + """ + def __init__(self, config_list): + super().__init__(config_list) + self.layer_scale = {} + + def quantize_weight(self, weight, config, op_name, **kwargs): + new_scale = weight.abs().max() / 127 + scale = max(self.layer_scale.get(op_name, 0), new_scale) + self.layer_scale[op_name] = scale + orig_type = weight.type() # TODO: user layer + return weight.div(scale).type(torch.int8).type(orig_type).mul(scale) + + +class QAT_Quantizer(Quantizer): + """ + Quantizer using the DoReFa scheme, as defined in: + Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference + http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf + """ + def __init__(self, config_list): + """ + Configure Args: + q_bits + """ + super().__init__(config_list) + + def quantize_weight(self, weight, config, **kwargs): + if config['q_bits'] <= 1: + return weight + a = torch.min(weight) + b = torch.max(weight) + n = pow(2, config['q_bits']) + scale = (b-a)/(n-1) + zero_point = a + out = torch.round((weight - zero_point)/scale) + out = out*scale + zero_point + orig_type = weight.dtype + return out.type(orig_type) + + +class DoReFaQuantizer(Quantizer): + """ + Quantizer using the DoReFa scheme, as defined in: + Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients + (https://arxiv.org/abs/1606.06160) + """ + def __init__(self, config_list): + """ + configure Args: + q_bits + """ + super().__init__(config_list) + + def quantize_weight(self, weight, config, **kwargs): + out = weight.tanh() + out = out /( 2 * out.abs().max()) + 0.5 + out = self.quantize(out, config['q_bits']) + out = 2 * out -1 + return out + + def quantize(self, input_ri, q_bits): + scale = pow(2, q_bits)-1 + output = torch.round(input_ri*scale)/scale + return output diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py new file mode 100644 index 0000000000..6282a2138c --- /dev/null +++ b/src/sdk/pynni/nni/compression/torch/compressor.py @@ -0,0 +1,162 @@ +import torch +import logging +from . import default_layers + +_logger = logging.getLogger(__name__) + + +class LayerInfo: + def __init__(self, name, module): + self.module = module + self.name = name + self.type = type(module).__name__ + + self._forward = None + + +class Compressor: + """ + Abstract base PyTorch compressor + """ + def __init__(self, config_list): + self._bound_model = None + self._config_list = config_list + + def __call__(self, model): + self.compress(model) + return model + + def compress(self, model): + """ + Compress the model with algorithm implemented by subclass. + The model will be instrumented and user should never edit it after calling this method. + """ + assert self._bound_model is None, "Each NNI compressor instance can only compress one model" + self._bound_model = model + self.bind_model(model) + for name, module in model.named_modules(): + layer = LayerInfo(name, module) + config = self._select_config(layer) + if config is not None: + self._instrument_layer(layer, config) + + + def bind_model(self, model): + """ + This method is called when a model is bound to the compressor. + Users can optionally overload this method to do model-specific initialization. + It is guaranteed that only one model will be bound to each compressor instance. + """ + pass + + def update_epoch(self, epoch): + """ + if user want to update model every epoch, user can override this method + """ + pass + + def step(self): + """ + if user want to update model every step, user can override this method + """ + pass + + + def _instrument_layer(self, layer, config): + raise NotImplementedError() + + def _select_config(self, layer): + ret = None + for config in self._config_list: + op_types = config.get('op_types') + if op_types == 'default': + op_types = default_layers.weighted_modules + if op_types and layer.type not in op_types: + continue + if config.get('op_names') and layer.name not in config['op_names']: + continue + ret = config + if ret is None or ret.get('exclude'): + return None + return ret + + +class Pruner(Compressor): + """ + Abstract base PyTorch pruner + """ + def __init__(self, config_list): + super().__init__(config_list) + + def calc_mask(self, weight, config, op, op_type, op_name): + """ + Pruners should overload this method to provide mask for weight tensors. + The mask must have the same shape and type comparing to the weight. + It will be applied with `mul()` operation. + This method is effectively hooked to `forward()` method of the model. + """ + raise NotImplementedError("Pruners must overload calc_mask()") + + + def _instrument_layer(self, layer, config): + # TODO: support multiple weight tensors + # create a wrapper forward function to replace the original one + assert layer._forward is None, 'Each model can only be compressed once' + if not _check_weight(layer.module): + _logger.warning('Module {} does not have parameter "weight"'.format(layer.name)) + return + layer._forward = layer.module.forward + + def new_forward(*input): + # apply mask to weight + old_weight = layer.module.weight.data + mask = self.calc_mask(old_weight, config, op=layer.module, op_type=layer.type, op_name=layer.name) + layer.module.weight.data = old_weight.mul(mask) + # calculate forward + ret = layer._forward(*input) + # recover original weight + layer.module.weight.data = old_weight + return ret + + layer.module.forward = new_forward + + +class Quantizer(Compressor): + """ + Base quantizer for pytorch quantizer + """ + def __init__(self, config_list): + super().__init__(config_list) + + def __call__(self, model): + self.compress(model) + return model + + def quantize_weight(self, weight, config, op, op_type, op_name): + """ + user should know where dequantize goes and implement it in quantize method + we now do not provide dequantize method + """ + raise NotImplementedError("Quantizer must overload quantize_weight()") + + def _instrument_layer(self, layer, config): + assert layer._forward is None, 'Each model can only be compressed once' + if not _check_weight(layer.module): + _logger.warning('Module {} does not have parameter "weight"'.format(layer.name)) + return + layer._forward = layer.module.forward + + def new_forward(*input): + weight = layer.module.weight.data + new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name) + layer.module.weight.data = new_weight + return layer._forward(*input) + + layer.module.forward = new_forward + + +def _check_weight(module): + try: + return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor) + except AttributeError: + return False diff --git a/src/sdk/pynni/nni/compression/torch/default_layers.py b/src/sdk/pynni/nni/compression/torch/default_layers.py new file mode 100644 index 0000000000..185df8bfff --- /dev/null +++ b/src/sdk/pynni/nni/compression/torch/default_layers.py @@ -0,0 +1,6 @@ +weighted_modules = [ + 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', + 'Linear', 'Bilinear', + 'PReLU', + 'Embedding', 'EmbeddingBag', +] diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py new file mode 100644 index 0000000000..1c6021b0cd --- /dev/null +++ b/src/sdk/pynni/tests/test_compressor.py @@ -0,0 +1,116 @@ +from unittest import TestCase, main +import nni.compression.tensorflow as tf_compressor +import nni.compression.torch as torch_compressor +import torch +import torch.nn.functional as F +import tensorflow as tf + +def weight_variable(shape): + return tf.Variable(tf.truncated_normal(shape, stddev = 0.1)) + +def bias_variable(shape): + return tf.Variable(tf.constant(0.1, shape = shape)) + +def conv2d(x_input, w_matrix): + return tf.nn.conv2d(x_input, w_matrix, strides = [ 1, 1, 1, 1 ], padding = 'SAME') + +def max_pool(x_input, pool_size): + size = [ 1, pool_size, pool_size, 1 ] + return tf.nn.max_pool(x_input, ksize = size, strides = size, padding = 'SAME') + + +class TfMnist: + def __init__(self): + images = tf.placeholder(tf.float32, [ None, 784 ], name = 'input_x') + labels = tf.placeholder(tf.float32, [ None, 10 ], name = 'input_y') + keep_prob = tf.placeholder(tf.float32, name='keep_prob') + + self.images = images + self.labels = labels + self.keep_prob = keep_prob + + self.train_step = None + self.accuracy = None + + self.w1 = None + self.b1 = None + self.fcw1 = None + self.cross = None + with tf.name_scope('reshape'): + x_image = tf.reshape(images, [ -1, 28, 28, 1 ]) + with tf.name_scope('conv1'): + w_conv1 = weight_variable([ 5, 5, 1, 32 ]) + self.w1 = w_conv1 + b_conv1 = bias_variable([ 32 ]) + self.b1 = b_conv1 + h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) + with tf.name_scope('pool1'): + h_pool1 = max_pool(h_conv1, 2) + with tf.name_scope('conv2'): + w_conv2 = weight_variable([ 5, 5, 32, 64 ]) + b_conv2 = bias_variable([ 64 ]) + h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2) + with tf.name_scope('pool2'): + h_pool2 = max_pool(h_conv2, 2) + with tf.name_scope('fc1'): + w_fc1 = weight_variable([ 7 * 7 * 64, 1024 ]) + self.fcw1 = w_fc1 + b_fc1 = bias_variable([ 1024 ]) + h_pool2_flat = tf.reshape(h_pool2, [ -1, 7 * 7 * 64 ]) + h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1) + with tf.name_scope('dropout'): + h_fc1_drop = tf.nn.dropout(h_fc1, 0.5) + with tf.name_scope('fc2'): + w_fc2 = weight_variable([ 1024, 10 ]) + b_fc2 = bias_variable([ 10 ]) + y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2 + with tf.name_scope('loss'): + cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = labels, logits = y_conv)) + self.cross = cross_entropy + with tf.name_scope('adam_optimizer'): + self.train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy) + with tf.name_scope('accuracy'): + correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(labels, 1)) + self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + +class TorchMnist(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) + self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) + self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) + self.fc2 = torch.nn.Linear(500, 10) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.max_pool2d(x, 2, 2) + x = F.relu(self.conv2(x)) + x = F.max_pool2d(x, 2, 2) + x = x.view(-1, 4 * 4 * 50) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim = 1) + +class CompressorTestCase(TestCase): + def test_tf_pruner(self): + model = TfMnist() + configure_list = [{'sparsity':0.8, 'op_types':'default'}] + tf_compressor.LevelPruner(configure_list).compress_default_graph() + + + def test_tf_quantizer(self): + model = TfMnist() + tf_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress_default_graph() + + def test_torch_pruner(self): + model = TorchMnist() + configure_list = [{'sparsity':0.8, 'op_types':'default'}] + torch_compressor.LevelPruner(configure_list).compress(model) + + def test_torch_quantizer(self): + model = TorchMnist() + torch_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress(model) + + +if __name__ == '__main__': + main()