API Improvement for paddle.nn.layer.state_dict (usability improvement) #64358

Merged: 8 commits, May 24, 2024
21 changes: 18 additions & 3 deletions python/paddle/nn/layer/layers.py
@@ -1871,6 +1871,7 @@ def _state_dict_impl(
         structured_name_prefix="",
         include_non_persistable_buffer=False,
         use_hook=True,
+        keep_vars=True,
Contributor: If the default were False like PyTorch's, i.e. saving the detached contents, would that cause any problems?

Contributor (Author): Before this change the original parameters were saved, without detaching. If the default were changed to False, wouldn't that be incompatible with existing code?

Contributor:

> Before this change the original parameters were saved, without detaching. If the default were changed to False, wouldn't that be incompatible with existing code?

OK, please submit a follow-up PR for the documentation changes.
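To make the trade-off discussed above concrete, here is a minimal sketch (a plain nn.Linear, not part of this PR) of what the two settings imply: the default keep_vars=True returns the live parameters, so gradients stay attached and existing code keeps working, while keep_vars=False returns detached copies, matching PyTorch's default.

    import paddle
    import paddle.nn as nn

    layer = nn.Linear(4, 2)
    layer(paddle.randn([3, 4])).sum().backward()

    live = layer.state_dict()                     # default keep_vars=True
    detached = layer.state_dict(keep_vars=False)  # detached from autograd

    print(live["weight"].grad is not None)      # True: live parameter, grad intact
    print(detached["weight"].grad is not None)  # False: detached copy has no grad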

     ):
         """
         Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
@@ -1880,23 +1881,30 @@ def _state_dict_impl(
             include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True.
             include_non_persistable_buffer(bool, optional): If true, include non persistable buffers of current layer and its sub-layers, it is used in pure fp16 and jit.save. Default: False.
             use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True.
+            keep_vars(bool, optional) : If false, the returned tensors in the state dict are detached from autograd. Default: True.
         """
 
         if destination is None:
             destination = collections.OrderedDict()
         for name, data in self._parameters.items():
             if data is not None:
-                destination[structured_name_prefix + name] = data
+                destination[structured_name_prefix + name] = (
+                    data if keep_vars else data.detach()
+                )
         for name, buffer in self._buffers.items():
             if not include_non_persistable_buffer:
                 if (
                     buffer is not None
                     and name not in self._non_persistable_buffer_names_set
                 ):
-                    destination[structured_name_prefix + name] = buffer
+                    destination[structured_name_prefix + name] = (
+                        buffer if keep_vars else buffer.detach()
+                    )
             else:
                 if buffer is not None:
-                    destination[structured_name_prefix + name] = buffer
+                    destination[structured_name_prefix + name] = (
+                        buffer if keep_vars else buffer.detach()
+                    )
 
         if include_sublayers:
             for layer_name, layer_item in self._sub_layers.items():
@@ -1909,6 +1917,7 @@ def _state_dict_impl(
                             structured_name_prefix + layer_name + ".",
                             include_non_persistable_buffer,
                             use_hook,
+                            keep_vars,
                         )
                     )
                     destination = destination_temp
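The recursion above forwards keep_vars unchanged to every sublayer, so a single call controls the detach policy for the whole layer tree. A simplified standalone sketch of the pattern (plain dicts stand in for Paddle's Layer; collect_state is illustrative, not the real API):

    import paddle

    def collect_state(layer, prefix="", keep_vars=True, destination=None):
        # Walks a {"params": ..., "sublayers": ...} tree the way
        # _state_dict_impl walks self._parameters and self._sub_layers.
        if destination is None:
            destination = {}
        for name, tensor in layer["params"].items():
            destination[prefix + name] = tensor if keep_vars else tensor.detach()
        for child_name, child in layer["sublayers"].items():
            collect_state(child, prefix + child_name + ".", keep_vars, destination)
        return destination

    tree = {
        "params": {"w": paddle.randn([2, 2])},
        "sublayers": {"fc": {"params": {"b": paddle.zeros([2])}, "sublayers": {}}},
    }
    print(sorted(collect_state(tree, keep_vars=False)))  # ['fc.b', 'w']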
@@ -1926,6 +1935,7 @@ def to_static_state_dict(
         include_sublayers=True,
         structured_name_prefix="",
         use_hook=True,
+        keep_vars=True,
     ):
         '''
 
@@ -1935,6 +1945,7 @@ def to_static_state_dict(
             destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None.
             include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True.
             use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True.
+            keep_vars(bool, optional) : If false, the returned tensors in the state dict are detached from autograd. Default: True.
 
         Returns:
             dict, a dict contains all the parameters and persistable buffers.
@@ -1956,6 +1967,7 @@ def to_static_state_dict(
             structured_name_prefix=structured_name_prefix,
             include_non_persistable_buffer=True,
             use_hook=use_hook,
+            keep_vars=keep_vars,
         )
 
     def state_dict(
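Note that to_static_state_dict above differs from state_dict below only in passing include_non_persistable_buffer=True; keep_vars behaves identically in both. A sketch of the difference (WithScratchBuffer is a hypothetical example layer):

    import paddle
    import paddle.nn as nn

    class WithScratchBuffer(nn.Layer):  # hypothetical example layer
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)
            self.register_buffer("scratch", paddle.zeros([4]), persistable=False)

    m = WithScratchBuffer()
    print("scratch" in m.state_dict())                           # False: non-persistable
    print("scratch" in m.to_static_state_dict(keep_vars=False))  # True, and detached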
@@ -1964,6 +1976,7 @@ def state_dict(
         include_sublayers=True,
         structured_name_prefix="",
         use_hook=True,
+        keep_vars=True,
     ):
         '''
         Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
@@ -1972,6 +1985,7 @@ def state_dict(
             destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None.
             include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True.
             use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True.
+            keep_vars(bool, optional) : If false, the returned tensors in the state dict are detached from autograd. Default: True.
 
         Returns:
             dict: a dict contains all the parameters and persistable buffers.
@@ -1993,6 +2007,7 @@ def state_dict(
             structured_name_prefix=structured_name_prefix,
             include_non_persistable_buffer=False,
             use_hook=use_hook,
+            keep_vars=keep_vars,
         )
 
     @framework.deprecate_stat_dict
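With all three entry points updated, a typical use of the new argument looks like this (a usage sketch; the model and file name are arbitrary):

    import paddle
    import paddle.nn as nn

    model = nn.Sequential(nn.Linear(100, 300))
    model(paddle.randn([5, 100])).sum().backward()

    # Default (keep_vars=True) preserves the old behavior: live tensors.
    st = model.state_dict()

    # keep_vars=False detaches every entry, e.g. before saving, so the
    # checkpoint carries no autograd state.
    paddle.save(model.state_dict(keep_vars=False), "model.pdparams")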
41 changes: 41 additions & 0 deletions test/legacy_test/test_state_dict_convert.py
@@ -35,12 +35,14 @@ def state_dict(
         include_sublayers=True,
         structured_name_prefix="",
         use_hook=True,
+        keep_vars=True,
     ):
         st = super().state_dict(
             destination=destination,
             include_sublayers=include_sublayers,
             structured_name_prefix=structured_name_prefix,
             use_hook=use_hook,
+            keep_vars=keep_vars,
         )
         st["linear.new_weight"] = paddle.transpose(
             st.pop("linear.weight"), [1, 0]
@@ -75,6 +77,17 @@ def is_state_dict_equal(model1, model2):
     return True
 
 
+class MyModel3(nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(100, 300)
+        buffer = paddle.to_tensor([0.0])
+        self.register_buffer("model_buffer", buffer, persistable=True)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
 class TestStateDictConvert(unittest.TestCase):
     def test_main(self):
         model1 = MyModel()
@@ -97,5 +110,33 @@ def test_missing_keys_and_unexpected_keys(self):
         self.assertEqual(unexpected_keys[0], "unexpected_keys")
 
 
+class TestStateKeepVars(unittest.TestCase):
+    def test_true(self):
+        model = MyModel3()
+        x = paddle.randn([5, 100])
+        y = model(x)
+        y.backward()
+        st = model.state_dict()
+        has_grad = (
+            (st["linear.weight"].grad == model.linear.weight.grad).all()
+            and (st["linear.bias"].grad == model.linear.bias.grad).all()
+            and st["model_buffer"].grad == model.model_buffer.grad
+        )
+        self.assertEqual(has_grad, True)
+
+    def test_false(self):
+        model = MyModel3()
+        x = paddle.randn([5, 100])
+        y = model(x)
+        y.backward()
+        st = model.state_dict(keep_vars=False)
+        has_grad = (
Contributor: Shouldn't this be

    (st["linear.weight"].grad is not None)
    and (st["linear.bias"].grad is not None)
    and (st["model_buffer"].grad is not None)

Contributor (Author): Fixed.

(st["linear.weight"].grad is not None)
and (st["linear.bias"].grad is not None)
and (st["model_buffer"].grad is not None)
)
self.assertEqual(has_grad, False)


 if __name__ == "__main__":
     unittest.main()
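One practical consequence, visible in the MyModel override near the top of this test file: a subclass that overrides state_dict must accept and forward keep_vars, otherwise calls such as state_dict(keep_vars=False) break. A sketch of the pattern (RenamedLinear is a hypothetical name mirroring MyModel):

    import paddle
    import paddle.nn as nn

    class RenamedLinear(nn.Layer):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(10, 20)

        def forward(self, x):
            return self.linear(x)

        def state_dict(
            self,
            destination=None,
            include_sublayers=True,
            structured_name_prefix="",
            use_hook=True,
            keep_vars=True,
        ):
            # Forward everything, including keep_vars, so the base class
            # applies the caller's detach policy before keys are renamed.
            st = super().state_dict(
                destination=destination,
                include_sublayers=include_sublayers,
                structured_name_prefix=structured_name_prefix,
                use_hook=use_hook,
                keep_vars=keep_vars,
            )
            st["linear.new_weight"] = paddle.transpose(
                st.pop("linear.weight"), [1, 0]
            )
            return st

    st = RenamedLinear().state_dict(keep_vars=False)
    assert st["linear.new_weight"].grad is None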