diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 393f1c1570453c..e8c263fe033554 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -379,7 +379,9 @@ def keep_name_table(self, value): def _parse_save_configs(configs): - supported_configs = ['output_spec', "with_hook", "use_combine"] + supported_configs = [ + 'output_spec', "with_hook", "combine_params", "clip_extra" + ] # input check for key in configs: @@ -392,7 +394,8 @@ def _parse_save_configs(configs): inner_config = _SaveLoadConfig() inner_config.output_spec = configs.get('output_spec', None) inner_config.with_hook = configs.get('with_hook', False) - inner_config.combine_params = configs.get("use_combine", False) + inner_config.combine_params = configs.get("combine_params", False) + inner_config.clip_extra = configs.get("clip_extra", False) return inner_config @@ -1015,7 +1018,7 @@ def fun(inputs): params_filename=params_filename, export_for_deployment=configs._export_for_deployment, program_only=configs._program_only, - clip_extra=False) + clip_extra=configs.clip_extra) # collect all vars for var in concrete_program.main_program.list_vars(): diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index f467fbe4888e64..fd4129f47ff65f 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -1209,7 +1209,7 @@ def test_save_load_finetune_load(self): with unique_name.guard(): net = Net() #save - paddle.jit.save(net, model_path, use_combine=True) + paddle.jit.save(net, model_path, combine_params=True) class LayerLoadFinetune(paddle.nn.Layer):