API improvement for paddle.nn.layer.state_dict: usability improvement #64358
Conversation
Your PR was submitted successfully. Thank you for contributing to this open-source project!
y = model(x)
y.backward()
st = model.state_dict(keep_vars=False)
detached_from_graph = (
This comparison is still hard to follow; suggest simply:
has_grad = (st["linear.weight"].grad == model.linear.weight.grad).all() and (st["linear.bias"].grad == model.linear.bias.grad).all() and st["model_buffer"].grad == model.model_buffer.grad
self.assertEqual(has_grad, False)
Since .grad returns None after detach, change it to this instead:
has_grad = (
(st["linear.weight"].grad is not None)
or (st["linear.bias"].grad is not None)
or (st["model_buffer"].grad is not None)
)
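The behavior behind this suggestion can be checked against PyTorch's analogous keep_vars parameter (a sketch of the general detach/.grad semantics, not Paddle's implementation):

```python
import torch

model = torch.nn.Linear(3, 1)
loss = model(torch.randn(2, 3)).sum()
loss.backward()

# keep_vars=False: entries are detached tensors, so .grad is None
# even though the live parameters still hold gradients.
st = model.state_dict(keep_vars=False)
assert st["weight"].grad is None
assert model.weight.grad is not None

# keep_vars=True: the live parameters are returned, so .grad survives.
st_live = model.state_dict(keep_vars=True)
assert st_live["weight"].grad is not None
```

This is why an `is not None` check on `.grad`, rather than an element-wise comparison, is the direct way to test that the returned tensors left the graph.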
y = model(x)
y.backward()
st = model.state_dict()
detached_from_graph = (
This comparison is still hard to follow; suggest simply:
has_grad = (st["linear.weight"].grad == model.linear.weight.grad).all() and (st["linear.bias"].grad == model.linear.bias.grad).all() and st["model_buffer"].grad == model.model_buffer.grad
self.assertEqual(has_grad, True)
Fixed.
y = model(x)
y.backward()
st = model.state_dict(keep_vars=False)
has_grad = (
Shouldn't this be:
(st["linear.weight"].grad is not None)
and (st["linear.bias"].grad is not None)
and (st["model_buffer"].grad is not None)
instead?
Fixed.
LGTM
@@ -1871,6 +1871,7 @@ def _state_dict_impl(
    structured_name_prefix="",
    include_non_persistable_buffer=False,
    use_hook=True,
    keep_vars=True,
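The added parameter can be read as a flag that gates detaching inside the state-dict builder. A minimal sketch of that pattern (hypothetical helper name, illustrated with PyTorch tensors rather than Paddle's actual code):

```python
import torch

def state_dict_impl(named_params, keep_vars=True):
    """Hypothetical helper: return live tensors when keep_vars is True,
    detached copies (no .grad, no graph) when it is False."""
    return {
        name: (p if keep_vars else p.detach())
        for name, p in named_params
    }

w = torch.nn.Parameter(torch.ones(2))
(w * 3).sum().backward()

live = state_dict_impl([("w", w)], keep_vars=True)
frozen = state_dict_impl([("w", w)], keep_vars=False)
assert live["w"] is w            # same object, grad preserved
assert frozen["w"].grad is None  # detached view, no grad
```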
If the default were False like PyTorch's (i.e., saving detached tensors), would that cause any problems?
Before this change, the original parameters were saved without being detached. Wouldn't changing the default to False be incompatible with existing code?
> Before this change, the original parameters were saved without being detached. Wouldn't changing the default to False be incompatible with existing code?

OK, please also submit a follow-up PR updating the documentation.
PR Category
User Experience
PR Types
Improvements
Description
Add a keep_vars parameter with default True; when it is False, the tensors in the returned state_dict are detached from the computation graph.
Note the parameter's default is the opposite of torch's: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.state_dict
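The default-value difference called out above can be demonstrated against torch itself, where keep_vars defaults to False (a sketch using PyTorch; the Paddle API described in this PR defaults the other way):

```python
import torch

model = torch.nn.Linear(2, 2)
model(torch.ones(1, 2)).sum().backward()

# torch's default is keep_vars=False: entries are detached copies
# with no gradient attached.
default_st = model.state_dict()
assert default_st["weight"].grad is None

# Paddle's proposed default (keep_vars=True) corresponds to torch's
# non-default call: the live parameter objects themselves are returned.
live_st = model.state_dict(keep_vars=True)
assert live_st["weight"] is model.weight
```

So Paddle callers who want torch-like, detached-from-graph snapshots must pass keep_vars=False explicitly, while existing code that relied on receiving the original parameters keeps working unchanged.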