Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
SymbolBlock.imports ignore_extra & allow_missing (#19156)
Browse files Browse the repository at this point in the history
  • Loading branch information
samskalicky committed Sep 17, 2020
1 parent d2e6452 commit 837c7e4
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,8 @@ class SymbolBlock(HybridBlock):
>>> print(feat_model(x))
"""
@staticmethod
def imports(symbol_file, input_names, param_file=None, ctx=None):
def imports(symbol_file, input_names, param_file=None, ctx=None, allow_missing=False,
ignore_extra=False):
"""Import model previously saved by `gluon.HybridBlock.export` or
`Module.save_checkpoint` as a `gluon.SymbolBlock` for use in Gluon.
Expand All @@ -1371,6 +1372,11 @@ def imports(symbol_file, input_names, param_file=None, ctx=None):
Path to parameter file.
ctx : Context, default None
The context to initialize `gluon.SymbolBlock` on.
allow_missing : bool, default False
Whether to silently skip loading parameters not represents in the file.
ignore_extra : bool, default False
Whether to silently ignore parameters from the file that are not
present in this Block.
Returns
-------
Expand Down Expand Up @@ -1404,7 +1410,8 @@ def imports(symbol_file, input_names, param_file=None, ctx=None):
inputs = [symbol.var(i).as_np_ndarray() if is_np_array() else symbol.var(i) for i in input_names]
ret = SymbolBlock(sym, inputs)
if param_file is not None:
ret.collect_params().load(param_file, ctx=ctx, cast_dtype=True, dtype_source='saved')
ret.collect_params().load(param_file, ctx, allow_missing, ignore_extra, cast_dtype=True,
dtype_source='saved')
return ret

def __repr__(self):
Expand Down

0 comments on commit 837c7e4

Please sign in to comment.