Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

SymbolBlock.imports ignore_extra & allow_missing #19156

Merged
merged 2 commits into from
Sep 17, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,8 @@ class SymbolBlock(HybridBlock):
>>> print(feat_model(x))
"""
@staticmethod
def imports(symbol_file, input_names, param_file=None, ctx=None):
def imports(symbol_file, input_names, param_file=None, ctx=None, allow_missing=False,
ignore_extra=False):
"""Import model previously saved by `gluon.HybridBlock.export` or
`Module.save_checkpoint` as a `gluon.SymbolBlock` for use in Gluon.

Expand All @@ -1371,6 +1372,11 @@ def imports(symbol_file, input_names, param_file=None, ctx=None):
Path to parameter file.
ctx : Context, default None
The context to initialize `gluon.SymbolBlock` on.
allow_missing : bool, default False
Whether to silently skip loading parameters not represents in the file.
ignore_extra : bool, default False
Whether to silently ignore parameters from the file that are not
present in this Block.

Returns
-------
Expand Down Expand Up @@ -1404,7 +1410,8 @@ def imports(symbol_file, input_names, param_file=None, ctx=None):
inputs = [symbol.var(i).as_np_ndarray() if is_np_array() else symbol.var(i) for i in input_names]
ret = SymbolBlock(sym, inputs)
if param_file is not None:
ret.collect_params().load(param_file, ctx=ctx, cast_dtype=True, dtype_source='saved')
ret.collect_params().load(param_file, ctx, allow_missing, ignore_extra, cast_dtype=True,
dtype_source='saved')
return ret

def __repr__(self):
Expand Down