Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Model Compression / TensorFlow] Support exporting pruned model #3487

Merged
merged 2 commits into from
Apr 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 86 additions & 4 deletions nni/compression/tensorflow/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ def _instrument(self, layer):

return layer

def _uninstrument(self, layer):
    """
    Strip instrumentation from ``layer`` and return the un-wrapped result.

    Wrappers are peeled off (and flagged as not instrumented) until a non-wrapper
    layer is reached. A ``Sequential`` or ``Model`` is then processed recursively
    by the matching helper; any other layer is returned unchanged.
    """
    # note that ``self._wrappers`` cache is not cleared here,
    # so the same wrapper objects will be recovered in next ``self._instrument()`` call
    while isinstance(layer, LayerWrapper):
        layer._instrumented = False
        layer = layer.layer

    # ``Sequential`` must be tested first because it is itself a ``Model``
    if isinstance(layer, tf.keras.Sequential):
        return self._uninstrument_sequential(layer)
    elif isinstance(layer, tf.keras.Model):
        return self._uninstrument_model(layer)
    else:
        return layer

def _instrument_sequential(self, seq):
layers = list(seq.layers) # seq.layers is read-only property
need_rebuild = False
Expand All @@ -97,6 +109,16 @@ def _instrument_sequential(self, seq):
need_rebuild = True
return tf.keras.Sequential(layers) if need_rebuild else seq

def _uninstrument_sequential(self, seq):
    """
    Un-instrument every layer of a ``Sequential`` model.

    Returns ``seq`` itself when none of its layers changed; otherwise returns a new
    ``Sequential`` built from the un-instrumented layers.
    """
    current = list(seq.layers)
    restored = [self._uninstrument(member) for member in current]
    # identity comparison: a layer counts as changed only if a different object came back
    changed = any(new is not old for new, old in zip(restored, current))
    if not changed:
        return seq
    return tf.keras.Sequential(restored)

def _instrument_model(self, model):
for key, value in list(model.__dict__.items()): # avoid "dictionary keys changed during iteration"
if isinstance(value, tf.keras.layers.Layer):
Expand All @@ -109,6 +131,17 @@ def _instrument_model(self, model):
value[i] = self._instrument(item)
return model

def _uninstrument_model(self, model):
    """
    Un-instrument layers stored as attributes of ``model``, in place.

    Attributes that are layers are replaced by their un-instrumented version when it
    differs; attributes that are lists have each layer item replaced in the list.
    The same ``model`` object is returned.
    """
    # iterate over a snapshot because attributes may be re-assigned inside the loop
    for attr_name, attr_value in tuple(model.__dict__.items()):
        if isinstance(attr_value, tf.keras.layers.Layer):
            restored = self._uninstrument(attr_value)
            if restored is not attr_value:
                setattr(model, attr_name, restored)
            continue

        if not isinstance(attr_value, list):
            continue

        for index, element in enumerate(attr_value):
            if isinstance(element, tf.keras.layers.Layer):
                attr_value[index] = self._uninstrument(element)
    return model

def _select_config(self, layer):
# Find the last matching config block for given layer.
Expand All @@ -129,6 +162,17 @@ def _select_config(self, layer):
return last_match


class LayerWrapper(tf.keras.Model):
    """
    Common abstract parent of all layer wrappers.

    Every concrete wrapper class has to derive from this class, so that wrappers can be
    told apart from ordinary layers with an ``isinstance`` check.
    """

    def __init__(self):
        super().__init__()
        # a newly constructed wrapper starts out flagged as instrumented
        self._instrumented = True


class Pruner(Compressor):
"""
Base class for pruning algorithms.
Expand Down Expand Up @@ -167,6 +211,43 @@ def compress(self):
self._update_mask()
return self.compressed_model

def export_model(self, model_path, mask_path=None):
    """
    Export pruned model and optionally mask tensors.

    The compressed model is un-instrumented before saving and instrumented again
    afterwards, even when building or saving the model fails.

    Parameters
    ----------
    model_path : path-like
        The path passed to ``Model.save()``.
        You can use ".h5" extension name to export HDF5 format.
    mask_path : path-like or None
        Export masks to the path when set.
        Because Keras cannot save tensors without a ``Model``,
        this will create a model, set all masks as its weights, and then save that model.
        Masks in saved model will be named by corresponding layer name in compressed model.
        Note that this mask file is not compatible with the PyTorch one:
        PyTorch masks are exported as PyTorch tensors.

    Returns
    -------
    None
    """
    _logger.info('Saving model to %s', model_path)
    # ``_build_input_shape`` is TensorFlow's private API; cannot find a public one
    input_shape = self.compressed_model._build_input_shape
    model = self._uninstrument(self.compressed_model)
    try:
        if input_shape:
            model.build(input_shape)
        model.save(model_path)
    finally:
        # always restore instrumentation, otherwise a failed save would leave
        # the wrappers flagged as un-instrumented
        self._instrument(model)

    if mask_path is not None:
        _logger.info('Saving masks to %s', mask_path)
        # can't find "save raw weights" API in tensorflow, so build a simple model
        mask_model = tf.keras.Model()
        for wrapper in self.wrappers:
            setattr(mask_model, wrapper.layer.name, wrapper.masks)
        mask_model.save_weights(mask_path)

    _logger.info('Done')

def calc_masks(self, wrapper, **kwargs):
"""
Abstract method to be overridden by algorithm. End users should ignore it.
Expand Down Expand Up @@ -199,7 +280,7 @@ def _update_mask(self):
wrapper.masks = masks


class PrunerLayerWrapper(tf.keras.Model):
class PrunerLayerWrapper(LayerWrapper):
"""
Instrumented TF layer.

Expand All @@ -210,8 +291,6 @@ class PrunerLayerWrapper(tf.keras.Model):

Attributes
----------
layer_info : LayerInfo
All static information of the original layer.
layer : tf.keras.layers.Layer
The original layer.
config : JSON object
Expand All @@ -233,6 +312,10 @@ def __init__(self, layer, config, pruner):
_logger.info('Layer detected to compress: %s', self.layer.name)

def call(self, *inputs):
    """
    Forward pass: refresh the wrapped layer's weights, then apply the wrapped layer.
    """
    self._update_weights()
    wrapped = self.layer
    return wrapped(*inputs)

def _update_weights(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we split this function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because forward should not be the only allowed place to update weights.

new_weights = []
for weight in self.layer.weights:
mask = self.masks.get(weight.name)
Expand All @@ -243,7 +326,6 @@ def call(self, *inputs):
if new_weights and not hasattr(new_weights[0], 'numpy'):
raise RuntimeError('NNI: Compressed model can only run in eager mode')
self.layer.set_weights([weight.numpy() for weight in new_weights])
return self.layer(*inputs)


# TODO: designed to replace `patch_optimizer`
Expand Down
33 changes: 31 additions & 2 deletions test/ut/sdk/test_compressor_tf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pathlib import Path
import tempfile
import unittest

import numpy as np
Expand All @@ -27,6 +29,9 @@
# This tensor is used as input of 10x10 linear layer, the first dimension is batch size
tensor1x10 = tf.constant([[1.0] * 10])

# This tensor is used as input of CNN models
image_tensor = tf.zeros([1, 10, 10, 3])


@unittest.skipIf(tf.__version__[0] != '2', 'Skip TF 1.x setup')
class TfCompressorTestCase(unittest.TestCase):
Expand All @@ -42,13 +47,37 @@ def _test_layer_detection_on_model(self, model):
layer_types = sorted(type(wrapper.layer).__name__ for wrapper in pruner.wrappers)
assert layer_types == ['Conv2D', 'Dense', 'Dense'], layer_types

def test_level_pruner(self):
def test_level_pruner_and_export_correctness(self):
    """
    Check the level pruner's output before export, after export, and on the reloaded model.
    """
    # prune 90% : 9.0 + 9.1 + ... + 9.9 = 94.5
    model = build_naive_model()
    pruner = pruners['level'](model)
    model = pruner.compress()

    x = model(tensor1x10)
    assert x.numpy() == 94.5

    # export into a private directory which is removed afterwards, instead of
    # writing fixed names into the shared system temp directory
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir = Path(temp_dir)
        pruner.export_model(temp_dir / 'model', temp_dir / 'mask')

        # because exporting will uninstrument and re-instrument the model,
        # we must test the model again
        x = model(tensor1x10)
        assert x.numpy() == 94.5

        # load and test exported model
        exported_model = tf.keras.models.load_model(temp_dir / 'model')
        x = exported_model(tensor1x10)
        assert x.numpy() == 94.5

def test_export_not_crash(self):
    """
    Smoke test: exporting a pruned CNN model and a pruned sequential model must not raise.
    """
    for model in [CnnModel(), build_sequential_model()]:
        pruner = pruners['level'](model)
        model = pruner.compress()
        # cannot use model.build(image_tensor.shape) here
        # it fails even without compression
        # seems TF's bug, not ours
        model(image_tensor)
        # keep the temporary directory alive while exporting and remove it afterwards
        with tempfile.TemporaryDirectory() as temp_dir:
            pruner.export_model(Path(temp_dir) / 'model')

try:
from tensorflow.keras import Model, Sequential
Expand Down