From 837c7e43c1c5b125944b339db94084ffee699129 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Wed, 16 Sep 2020 17:30:06 -0700 Subject: [PATCH] SymbolBlock.imports ignore_extra & allow_missing (#19156) --- python/mxnet/gluon/block.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 41ef2cb15d89..18063aa761e1 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -1357,7 +1357,8 @@ class SymbolBlock(HybridBlock): >>> print(feat_model(x)) """ @staticmethod - def imports(symbol_file, input_names, param_file=None, ctx=None): + def imports(symbol_file, input_names, param_file=None, ctx=None, allow_missing=False, + ignore_extra=False): """Import model previously saved by `gluon.HybridBlock.export` or `Module.save_checkpoint` as a `gluon.SymbolBlock` for use in Gluon. @@ -1371,6 +1372,11 @@ def imports(symbol_file, input_names, param_file=None, ctx=None): Path to parameter file. ctx : Context, default None The context to initialize `gluon.SymbolBlock` on. + allow_missing : bool, default False + Whether to silently skip loading parameters not represents in the file. + ignore_extra : bool, default False + Whether to silently ignore parameters from the file that are not + present in this Block. Returns ------- @@ -1404,7 +1410,8 @@ def imports(symbol_file, input_names, param_file=None, ctx=None): inputs = [symbol.var(i).as_np_ndarray() if is_np_array() else symbol.var(i) for i in input_names] ret = SymbolBlock(sym, inputs) if param_file is not None: - ret.collect_params().load(param_file, ctx=ctx, cast_dtype=True, dtype_source='saved') + ret.collect_params().load(param_file, ctx, allow_missing, ignore_extra, cast_dtype=True, + dtype_source='saved') return ret def __repr__(self):