-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
API Improvement for paddle.nn.layer.state_dict 易用性提升 #64358
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
c290082
update state_dict
NKNaN ca9f822
udpate
NKNaN a8ff048
fix test
NKNaN 5c53bc5
fix test
NKNaN 1477830
fix test
NKNaN b5bb2f1
fix test
NKNaN f40a666
update test
NKNaN 34ba1d2
update test
NKNaN File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,12 +35,14 @@ def state_dict( | |
include_sublayers=True, | ||
structured_name_prefix="", | ||
use_hook=True, | ||
keep_vars=True, | ||
): | ||
st = super().state_dict( | ||
destination=destination, | ||
include_sublayers=include_sublayers, | ||
structured_name_prefix=structured_name_prefix, | ||
use_hook=use_hook, | ||
keep_vars=keep_vars, | ||
) | ||
st["linear.new_weight"] = paddle.transpose( | ||
st.pop("linear.weight"), [1, 0] | ||
|
@@ -75,6 +77,17 @@ def is_state_dict_equal(model1, model2): | |
return True | ||
|
||
|
||
class MyModel3(nn.Layer): | ||
def __init__(self): | ||
super().__init__() | ||
self.linear = nn.Linear(100, 300) | ||
buffer = paddle.to_tensor([0.0]) | ||
self.register_buffer("model_buffer", buffer, persistable=True) | ||
|
||
def forward(self, x): | ||
return self.linear(x) | ||
|
||
|
||
class TestStateDictConvert(unittest.TestCase): | ||
def test_main(self): | ||
model1 = MyModel() | ||
|
@@ -97,5 +110,33 @@ def test_missing_keys_and_unexpected_keys(self): | |
self.assertEqual(unexpected_keys[0], "unexpected_keys") | ||
|
||
|
||
class TestStateKeepVars(unittest.TestCase): | ||
def test_true(self): | ||
model = MyModel3() | ||
x = paddle.randn([5, 100]) | ||
y = model(x) | ||
y.backward() | ||
st = model.state_dict() | ||
has_grad = ( | ||
(st["linear.weight"].grad == model.linear.weight.grad).all() | ||
and (st["linear.bias"].grad == model.linear.bias.grad).all() | ||
and st["model_buffer"].grad == model.model_buffer.grad | ||
) | ||
self.assertEqual(has_grad, True) | ||
|
||
def test_false(self): | ||
model = MyModel3() | ||
x = paddle.randn([5, 100]) | ||
y = model(x) | ||
y.backward() | ||
st = model.state_dict(keep_vars=False) | ||
has_grad = ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个不应该是
吗 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
(st["linear.weight"].grad is not None) | ||
and (st["linear.bias"].grad is not None) | ||
and (st["model_buffer"].grad is not None) | ||
) | ||
self.assertEqual(has_grad, False) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
默认值如果和Pytorch一样为False,即保存detach的内容。会有问题吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
修改之前是保存的原始参数,没有detach的,默认值改成False的话会不会对已有代码不兼容?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK,提交一下文档的修改PR吧