Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Model Compression / TensorFlow] Support exporting pruned model #3487

Merged
merged 2 commits into from
Apr 9, 2021

Conversation

liuzhe-lz
Copy link
Contributor

@liuzhe-lz liuzhe-lz commented Mar 29, 2021

No description provided.

@liuzhe-lz liuzhe-lz marked this pull request as ready for review April 2, 2021 07:23
None
"""
_logger.info('Saving model to %s', model_path)
input_shape = self.compressed_model._build_input_shape # cannot find a public API
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is _build_input_shape assigned?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is tensorflow's private API.

Export masks to the path when set.
Because Keras cannot save tensors without a ``Model``,
this will create a model, set all masks as its weights, and then save that model.
Masks in saved model will be named by corresponding layer name in compressed model.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so even the mask file has different format between pytorch and tensorflow?...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PyTorch masks are exported as PyTorch tensors. So there's no way to make TF masks compatible.

self._update_weights()
return self.layer(*inputs)

def _update_weights(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we split this function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because forward should not be the only allowed place to update weights.

@SparkSnail SparkSnail merged commit b7062b5 into microsoft:master Apr 9, 2021
@liuzhe-lz liuzhe-lz deleted the tf-compress branch June 17, 2021 03:26
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants