-
Notifications
You must be signed in to change notification settings - Fork 185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Example] Add RegAE example #660
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
辛苦修改一下
docs/zh/examples/RegAE.md
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
markdown文档用 vscode markdownlint 插件格式化一下
examples/RegAE/RegAE.py
Outdated
criterion = nn.MSELoss() | ||
kl_loss = KLLoss() | ||
|
||
|
||
def loss_expr(output_dict, label_dict, weight_dict=None): | ||
|
||
return kl_loss(output_dict) + criterion(output_dict["p_hat"], label_dict["p_hat"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 这里用F.mse_loss和KLLoss里的函数吧,简化下代码
loss_expr
函数定义放到sup_constraint = ...
的上方,不用作为全局函数- import不要直接导入某函数或者类,通过导入上一级的模块再访问
examples/RegAE/RegAE.py
Outdated
# set random seed for reproducibility | ||
ppsci.utils.misc.set_random_seed(cfg.seed) | ||
# initialize logger | ||
logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
==> eval.log
examples/RegAE/dataloader.py
Outdated
@@ -0,0 +1,142 @@ | |||
""" | |||
输入数据类型 10^5 * 100 * 100 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
数据类型为什么是一个乘法数值?不应该是float或者double吗?还是说10^5表示数值范围,后面的100*100表示样本数?感觉表述可以更改得准确一点
examples/RegAE/dataloader.py
Outdated
|
||
def transform(self, data): | ||
mean = ( | ||
paddle.to_tensor(self.mean).type_as(data).to(data.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
==> to_tensor(self.mean, dtype=data.dtype)
examples/RegAE/dataloader.py
Outdated
self.train_data = data[: self.train_len] | ||
self.test_data = data[self.train_len :] | ||
|
||
self.scaler = ScalerStd() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
scaler跟AMP的scaler存在歧义,改为 normalizer
ppsci/arch/__init__.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
API文档里加一下AutoEncoder
ppsci/arch/vae.py
Outdated
from ppsci.arch import base | ||
|
||
|
||
# copy from AISTUDIO |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这行删掉吧
ppsci/arch/vae.py
Outdated
class AutoEncoder(base.Arch): | ||
def __init__( | ||
self, | ||
input_keys: Tuple[str, ...], | ||
output_keys: Tuple[str, ...], | ||
input_dim, | ||
latent_dim, | ||
hidden_dim, | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
完善docstring、type hint
mu, log_sigma = output_dict["mu"], output_dict["log_sigma"] | ||
|
||
base = paddle.exp(2.0 * log_sigma) + paddle.pow(mu, 2) - 1.0 - 2.0 * log_sigma | ||
loss = 0.5 * paddle.sum(base) / mu.shape[0] | ||
|
||
return loss |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是KL散度损失吗?但是看起来跟nn.KLDivLoss不是很像,如果不是很通用的Loss的话还是放在案例里通过FunctionalLoss使用吧,如果是比较通用的,就需要写成KLDiv的形式,即
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* add RegAE example * add RegAE --------- Co-authored-by: HydrogenSulfate <490868991@qq.com>
PR types
New features
PR changes
Others
Describe