-
Notifications
You must be signed in to change notification settings - Fork 1.8k
[Model Compression / TensorFlow] Support exporting pruned model #3487
Conversation
None | ||
""" | ||
_logger.info('Saving model to %s', model_path) | ||
input_shape = self.compressed_model._build_input_shape # cannot find a public API |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where is _build_input_shape
assigned?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is tensorflow's private API.
Export masks to the path when set. | ||
Because Keras cannot save tensors without a ``Model``, | ||
this will create a model, set all masks as its weights, and then save that model. | ||
Masks in saved model will be named by corresponding layer name in compressed model. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so even the mask file has different format between pytorch and tensorflow?...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PyTorch masks are exported as PyTorch tensors. So there's no way to make TF masks compatible.
self._update_weights() | ||
return self.layer(*inputs) | ||
|
||
def _update_weights(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why we split this function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because forward should not be the only allowed place to update weights.
No description provided.