Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix block.export (#17970)
Browse files Browse the repository at this point in the history
* fix block.export

```net.hybridize``` may optimize out some ops. These ops are alive in nn.Block(also nn.HybridBlock), but its names are not contained in symbol's ```arg_names``` list. So ignore these ops except that their name are end with 'running_mean' or 'running_var'.

* Update block.py

let user can save their extra param.

* add allow_extra

add allow_extra to let user decide whether to save extra parameters or not.

* Update block.py

add moving_mean and moving_var when export model with SymbolBlock

* Update python/mxnet/gluon/block.py

typo

Co-authored-by: Sheng Zha <szha@users.noreply.github.com>

* Update block.py

* Update block.py

* Update python/mxnet/gluon/block.py

Co-authored-by: Leonard Lausen <leonard@lausen.nl>

Co-authored-by: Sheng Zha <szha@users.noreply.github.com>
Co-authored-by: Leonard Lausen <leonard@lausen.nl>
  • Loading branch information
3 people authored Sep 1, 2020
1 parent 8379740 commit 5122d32
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,8 +1348,11 @@ def export(self, path, epoch=0, remove_amp_cast=True):
if name in arg_names:
arg_dict['arg:{}'.format(name)] = param._reduce()
else:
assert name in aux_names
arg_dict['aux:{}'.format(name)] = param._reduce()
if name not in aux_names:
warnings.warn('Parameter "{name}" is not found in the graph. '
.format(name=name), stacklevel=3)
else:
arg_dict['aux:%s'%name] = param._reduce()
save_fn = _mx_npx.save if is_np_array() else ndarray.save
params_filename = '%s-%04d.params'%(path, epoch)
save_fn(params_filename, arg_dict)
Expand Down

0 comments on commit 5122d32

Please sign in to comment.