This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
[Model Compression / TensorFlow] Support exporting pruned model #3487
Merged
Changes from all commits (2 commits)
@@ -87,6 +87,18 @@ def _instrument(self, layer):
        return layer

    def _uninstrument(self, layer):
        # note that ``self._wrappers`` cache is not cleared here,
        # so the same wrapper objects will be recovered in next ``self._instrument()`` call
        if isinstance(layer, LayerWrapper):
            layer._instrumented = False
            return self._uninstrument(layer.layer)
        if isinstance(layer, tf.keras.Sequential):
            return self._uninstrument_sequential(layer)
        if isinstance(layer, tf.keras.Model):
            return self._uninstrument_model(layer)
        return layer

    def _instrument_sequential(self, seq):
        layers = list(seq.layers)  # seq.layers is read-only property
        need_rebuild = False
@@ -97,6 +109,16 @@ def _instrument_sequential(self, seq):
                need_rebuild = True
        return tf.keras.Sequential(layers) if need_rebuild else seq

    def _uninstrument_sequential(self, seq):
        layers = list(seq.layers)
        rebuilt = False
        for i, layer in enumerate(layers):
            orig_layer = self._uninstrument(layer)
            if orig_layer is not layer:
                layers[i] = orig_layer
                rebuilt = True
        return tf.keras.Sequential(layers) if rebuilt else seq

    def _instrument_model(self, model):
        for key, value in list(model.__dict__.items()):  # avoid "dictionary keys changed during iteration"
            if isinstance(value, tf.keras.layers.Layer):
@@ -109,6 +131,17 @@ def _instrument_model(self, model):
                        value[i] = self._instrument(item)
        return model

    def _uninstrument_model(self, model):
        for key, value in list(model.__dict__.items()):
            if isinstance(value, tf.keras.layers.Layer):
                orig_layer = self._uninstrument(value)
                if orig_layer is not value:
                    setattr(model, key, orig_layer)
            elif isinstance(value, list):
                for i, item in enumerate(value):
                    if isinstance(item, tf.keras.layers.Layer):
                        value[i] = self._uninstrument(item)
        return model

    def _select_config(self, layer):
        # Find the last matching config block for given layer.
@@ -129,6 +162,17 @@ def _select_config(self, layer):
        return last_match


class LayerWrapper(tf.keras.Model):
    """
    Abstract base class of layer wrappers.

    Concrete layer wrapper classes must inherit this to support ``isinstance`` check.
    """
    def __init__(self):
        super().__init__()
        self._instrumented = True


class Pruner(Compressor):
    """
    Base class for pruning algorithms.
@@ -167,6 +211,43 @@ def compress(self):
        self._update_mask()
        return self.compressed_model

    def export_model(self, model_path, mask_path=None):
        """
        Export pruned model and optionally mask tensors.

        Parameters
        ----------
        model_path : path-like
            The path passed to ``Model.save()``.
            You can use ".h5" extension name to export HDF5 format.
        mask_path : path-like or None
            Export masks to the path when set.
            Because Keras cannot save tensors without a ``Model``,
            this will create a model, set all masks as its weights, and then save that model.
            Masks in saved model will be named by corresponding layer name in compressed model.

        Returns
        -------
        None
        """
        _logger.info('Saving model to %s', model_path)
        input_shape = self.compressed_model._build_input_shape  # cannot find a public API
Inline review comment on the ``_build_input_shape`` line: where is
Reply: This is tensorflow's private API.
        model = self._uninstrument(self.compressed_model)
        if input_shape:
            model.build(input_shape)
        model.save(model_path)
        self._instrument(model)

        if mask_path is not None:
            _logger.info('Saving masks to %s', mask_path)
            # can't find "save raw weights" API in tensorflow, so build a simple model
            mask_model = tf.keras.Model()
            for wrapper in self.wrappers:
                setattr(mask_model, wrapper.layer.name, wrapper.masks)
            mask_model.save_weights(mask_path)

        _logger.info('Done')

    def calc_masks(self, wrapper, **kwargs):
        """
        Abstract method to be overridden by algorithm. End users should ignore it.
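As a rough usage sketch (not part of this diff): after ``compress()`` and fine-tuning, the new ``export_model`` API would be called as below. The ``LevelPruner`` import path and the config list format are assumptions borrowed from NNI's existing TensorFlow pruning examples and may differ.

```python
# Hypothetical usage sketch; everything outside compress()/export_model() is assumed.
import tensorflow as tf
from nni.algorithms.compression.tensorflow.pruning import LevelPruner  # import path assumed

model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10),
])

config_list = [{'sparsity': 0.5, 'op_types': ['default']}]  # config format assumed
pruner = LevelPruner(model, config_list)
pruned_model = pruner.compress()

# ... fine-tune pruned_model in eager mode ...

# New in this PR: save the un-instrumented pruned model plus its mask tensors.
pruner.export_model('pruned_model.h5', mask_path='pruned_masks.h5')
```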
@@ -199,7 +280,7 @@ def _update_mask(self):
            wrapper.masks = masks


-class PrunerLayerWrapper(tf.keras.Model):
+class PrunerLayerWrapper(LayerWrapper):
    """
    Instrumented TF layer.
@@ -210,8 +291,6 @@ class PrunerLayerWrapper(tf.keras.Model):

    Attributes
    ----------
-    layer_info : LayerInfo
-        All static information of the original layer.
    layer : tf.keras.layers.Layer
        The original layer.
    config : JSON object
@@ -233,6 +312,10 @@ def __init__(self, layer, config, pruner):
        _logger.info('Layer detected to compress: %s', self.layer.name)

    def call(self, *inputs):
        self._update_weights()
        return self.layer(*inputs)

    def _update_weights(self):
Inline review comment on ``_update_weights``: why do we split this function?
Reply: Because forward should not be the only allowed place to update weights.
        new_weights = []
        for weight in self.layer.weights:
            mask = self.masks.get(weight.name)
@@ -243,7 +326,6 @@ def call(self, *inputs):
        if new_weights and not hasattr(new_weights[0], 'numpy'):
            raise RuntimeError('NNI: Compressed model can only run in eager mode')
        self.layer.set_weights([weight.numpy() for weight in new_weights])
-        return self.layer(*inputs)


# TODO: designed to replace `patch_optimizer`
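Regarding the split above: because ``_update_weights`` is now a separate method, masks can be flushed into the wrapped layer's weights without running a forward pass. A speculative illustration (not in this diff), relying only on the ``wrappers`` attribute shown earlier and on eager mode:

```python
# Speculative illustration only.
for wrapper in pruner.wrappers:
    wrapper._update_weights()   # write masked weights back into wrapper.layer
```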
Review comment: so even the mask file has a different format between pytorch and tensorflow?...
Reply: PyTorch masks are exported as PyTorch tensors. So there's no way to make TF masks compatible.
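For illustration, the TensorFlow-side masks written by ``export_model`` could be inspected with standard checkpoint tooling. A rough sketch, assuming ``mask_path`` was given as a checkpoint prefix (no ``.h5`` suffix); the exact variable names depend on how Keras tracks the attributes set on ``mask_model`` and are not guaranteed:

```python
# Rough sketch, not part of this PR.
import tensorflow as tf

reader = tf.train.load_checkpoint('pruned_masks')   # prefix passed as mask_path (assumed)
for name, shape in reader.get_variable_to_shape_map().items():
    mask = reader.get_tensor(name)                  # numpy array holding the mask values
    print(name, shape)
```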