diff --git a/python/paddle/nn/layer/layers.py b/python/paddle/nn/layer/layers.py
index 3651f772bb462..49a47ed99895a 100644
--- a/python/paddle/nn/layer/layers.py
+++ b/python/paddle/nn/layer/layers.py
@@ -1871,6 +1871,7 @@ def _state_dict_impl(
         structured_name_prefix="",
         include_non_persistable_buffer=False,
         use_hook=True,
+        keep_vars=True,
     ):
         """
         Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
@@ -1880,23 +1881,30 @@ def _state_dict_impl(
             include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True.
             include_non_persistable_buffer(bool, optional): If true, include non persistable buffers of current layer and its sub-layers, it is used in pure fp16 and jit.save. Default: False.
             use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True.
+            keep_vars(bool, optional) : If false, the returned tensors in the state dict are detached from autograd. Default: True.
         """

         if destination is None:
             destination = collections.OrderedDict()
         for name, data in self._parameters.items():
             if data is not None:
-                destination[structured_name_prefix + name] = data
+                destination[structured_name_prefix + name] = (
+                    data if keep_vars else data.detach()
+                )
         for name, buffer in self._buffers.items():
             if not include_non_persistable_buffer:
                 if (
                     buffer is not None
                     and name not in self._non_persistable_buffer_names_set
                 ):
-                    destination[structured_name_prefix + name] = buffer
+                    destination[structured_name_prefix + name] = (
+                        buffer if keep_vars else buffer.detach()
+                    )
             else:
                 if buffer is not None:
-                    destination[structured_name_prefix + name] = buffer
+                    destination[structured_name_prefix + name] = (
+                        buffer if keep_vars else buffer.detach()
+                    )

         if include_sublayers:
             for layer_name, layer_item in self._sub_layers.items():
@@ -1909,6 +1917,7 @@ def _state_dict_impl(
                             structured_name_prefix + layer_name + ".",
                             include_non_persistable_buffer,
                             use_hook,
+                            keep_vars,
                         )
                     )
                     destination = destination_temp
@@ -1926,6 +1935,7 @@ def to_static_state_dict(
         include_sublayers=True,
         structured_name_prefix="",
         use_hook=True,
+        keep_vars=True,
     ):
         '''

@@ -1935,6 +1945,7 @@ def to_static_state_dict(
             destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None.
             include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True.
             use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True.
+            keep_vars(bool, optional) : If false, the returned tensors in the state dict are detached from autograd. Default: True.

         Returns:
             dict, a dict contains all the parameters and persistable buffers.
@@ -1956,6 +1967,7 @@ def to_static_state_dict(
             structured_name_prefix=structured_name_prefix,
             include_non_persistable_buffer=True,
             use_hook=use_hook,
+            keep_vars=keep_vars,
         )

     def state_dict(
@@ -1964,6 +1976,7 @@ def state_dict(
         include_sublayers=True,
         structured_name_prefix="",
         use_hook=True,
+        keep_vars=True,
     ):
         '''
         Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
@@ -1972,6 +1985,7 @@ def state_dict(
             destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None.
             include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True.
             use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True.
+            keep_vars(bool, optional) : If false, the returned tensors in the state dict are detached from autograd. Default: True.

         Returns:
             dict: a dict contains all the parameters and persistable buffers.
@@ -1993,6 +2007,7 @@ def state_dict(
             structured_name_prefix=structured_name_prefix,
             include_non_persistable_buffer=False,
             use_hook=use_hook,
+            keep_vars=keep_vars,
         )

     @framework.deprecate_stat_dict
diff --git a/test/legacy_test/test_state_dict_convert.py b/test/legacy_test/test_state_dict_convert.py
index 90bdd3c1949f5..14cf6734a895f 100644
--- a/test/legacy_test/test_state_dict_convert.py
+++ b/test/legacy_test/test_state_dict_convert.py
@@ -35,12 +35,14 @@ def state_dict(
         include_sublayers=True,
         structured_name_prefix="",
         use_hook=True,
+        keep_vars=True,
     ):
         st = super().state_dict(
             destination=destination,
             include_sublayers=include_sublayers,
             structured_name_prefix=structured_name_prefix,
             use_hook=use_hook,
+            keep_vars=keep_vars,
         )
         st["linear.new_weight"] = paddle.transpose(
             st.pop("linear.weight"), [1, 0]
@@ -75,6 +77,17 @@ def is_state_dict_equal(model1, model2):
     return True


+class MyModel3(nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(100, 300)
+        buffer = paddle.to_tensor([0.0])
+        self.register_buffer("model_buffer", buffer, persistable=True)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
 class TestStateDictConvert(unittest.TestCase):
     def test_main(self):
         model1 = MyModel()
@@ -97,5 +110,33 @@ def test_missing_keys_and_unexpected_keys(self):
         self.assertEqual(unexpected_keys[0], "unexpected_keys")


+class TestStateKeepVars(unittest.TestCase):
+    def test_true(self):
+        model = MyModel3()
+        x = paddle.randn([5, 100])
+        y = model(x)
+        y.backward()
+        st = model.state_dict()
+        has_grad = (
+            (st["linear.weight"].grad == model.linear.weight.grad).all()
+            and (st["linear.bias"].grad == model.linear.bias.grad).all()
+            and st["model_buffer"].grad == model.model_buffer.grad
+        )
+        self.assertEqual(has_grad, True)
+
+    def test_false(self):
+        model = MyModel3()
+        x = paddle.randn([5, 100])
+        y = model(x)
+        y.backward()
+        st = model.state_dict(keep_vars=False)
+        has_grad = (
+            (st["linear.weight"].grad is not None)
+            and (st["linear.bias"].grad is not None)
+            and (st["model_buffer"].grad is not None)
+        )
+        self.assertEqual(has_grad, False)
+
+
 if __name__ == "__main__":
     unittest.main()
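For reviewers, a minimal usage sketch of the new flag, assuming the patch above is applied. The TinyNet model below is purely illustrative (analogous to MyModel3 in the test); only state_dict(keep_vars=...) comes from this change:

import paddle
from paddle import nn


class TinyNet(nn.Layer):
    # Hypothetical example model, not part of the patch.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


net = TinyNet()
net(paddle.randn([3, 4])).sum().backward()

# keep_vars=True (the default): entries are the live parameters, so the
# gradients produced by backward() remain visible through the dict.
live = net.state_dict()
print(live["linear.weight"].grad is not None)  # True

# keep_vars=False: entries are detach()-ed copies with no autograd history,
# convenient for serialization or inspection without touching grads.
frozen = net.state_dict(keep_vars=False)
print(frozen["linear.weight"].grad)  # None

Keeping the default at True preserves the existing behaviour for current callers; only code that opts in with keep_vars=False receives detached tensors.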