Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API Improvement for paddle.nn.layer.state_dict 易用性提升 #64358

Merged
merged 8 commits into from
May 24, 2024

Conversation

NKNaN
Copy link
Contributor

@NKNaN NKNaN commented May 16, 2024

PR Category

User Experience

PR Types

Improvements

Description

添加参数 keep_vars, 默认值 True,若为 False 则返回的 state_dict 中的 tensor 脱离计算图。
参数默认值与 torch 相反:https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.state_dict

Copy link

paddle-bot bot commented May 16, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label May 16, 2024
@NKNaN NKNaN changed the title API Improvement for paddle.nn.layer.state_dict API Improvement for paddle.nn.layer.state_dict 易用性提升 May 16, 2024
y = model(x)
y.backward()
st = model.state_dict(keep_vars=False)
detached_from_graph = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个比较方式还是不太好看明白,建议直接:

has_grad = (st["linear.weight"].grad == model.linear.weight.grad).all() and (st["linear.bias"].grad == model.linear.bias.grad).all() and st["model_buffer"].grad == model.model_buffer.grad
self.assertEqual(has_grad, False)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里因为detach了之后 .grad 返回的是 None,所以改成这样吧:
has_grad = (
(st["linear.weight"].grad is not None)
or (st["linear.bias"].grad is not None)
or (st["model_buffer"].grad is not None)
)

y = model(x)
y.backward()
st = model.state_dict()
detached_from_graph = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个比较方式还是不太好看明白,建议直接:

has_grad = (st["linear.weight"].grad == model.linear.weight.grad).all() and (st["linear.bias"].grad == model.linear.bias.grad).all() and st["model_buffer"].grad == model.model_buffer.grad
self.assertEqual(has_grad, True)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

y = model(x)
y.backward()
st = model.state_dict(keep_vars=False)
has_grad = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个不应该是

  (st["linear.weight"].grad is not None)
            and (st["linear.bias"].grad is not None)
            and (st["model_buffer"].grad is not None)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -1871,6 +1871,7 @@ def _state_dict_impl(
structured_name_prefix="",
include_non_persistable_buffer=False,
use_hook=True,
keep_vars=True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

默认值如果和Pytorch一样为False,即保存detach的内容。会有问题吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修改之前是保存的原始参数,没有detach的,默认值改成False的话会不会对已有代码不兼容?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修改之前是保存的原始参数,没有detach的,默认值改成False的话会不会对已有代码不兼容?

OK,提交一下文档的修改PR吧

@luotao1 luotao1 merged commit 12ecf2e into PaddlePaddle:develop May 24, 2024
31 of 32 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants