-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Dev compression speedup #1999
Dev compression speedup #1999
Conversation
…mpression-speedup
* add data parallel proposal * fix mask_weight bug * add slim pruner support and example * fix typo * fix typo * fix setattr error * fix buffer update * rename instrument_layer and prunerLayerWrapper * fix pylint * update reverse travsal * add wrap and unwrap * add register_buffer API * update docstring * update docstring * add quantizer support * fix typo * update MeanActivationPruner, weight_rank_filter_pruner and example
…t/nni into dev-compression-speedup
…t/nni into dev-compression-speedup
apply_comp = ApplyCompression(model, masks_file) | ||
apply_comp.compress() | ||
|
||
class ApplyCompression(Pruner): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why subclass Pruner?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This class is for validating the correctness of ModelSpeedup, i.e., compare inference result from ModelSpeedup and that from ApplyCompression. ApplyCompression simply applies the masks, from implementation, it can be seen as a simple pruner.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it can simply be a method to go through all masks and apply to the weight accordingly. Because there is no need to hook forward method like pruner does and all cal_mask does is a dict lookup. To sum up, it is not a pruner and subclassing pruner makes it look weird. A simple function would do exactly the same and make much more sense?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good suggestion, let me fix it in a folow-up pr.
parser.add_argument("--model_checkpoint", type=str, default=None, help="the path of checkpointed model") | ||
args = parser.parse_args() | ||
|
||
if args.example_name == 'slim': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems *_speedup have a lot same code, may be they can be put in one method? use dynamic module import based on parser arguments?
print('mask elapsed time: ', time.time() - start) | ||
return | ||
else: | ||
#print("model before: ", model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
may be removed this or replace it with logger.debug and use flag to set logging level.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, will do it
@Cjkkkk your comments are very helpful, will fix them in the next pr soon. |
TODO:
print
with logging