-
Notifications
You must be signed in to change notification settings - Fork 187
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Example] Add VAE example #573
Conversation
Thanks for your contribution! |
✅ This PR's description meets the template requirements! |
感谢大佬提PR,可以合并一下最新develop分支的代码,然后案例可以用hydra改造一下,参考这个案例:https://github.com/PaddlePaddle/PaddleScience/tree/develop/examples/bracket |
@HydrogenSulfate 不好意思,打扰了,AISTUDIO原文用到了 原文地址 :https://aistudio.baidu.com/projectdetail/5541961 # julia 依赖
os.environ['JULIA_DEPOT_PATH'] = '/home/aistudio/opt/julia_package'
# pip 依赖
sys.path.append('/home/aistudio/opt/external-libraries')
# julieries
from julia.api import Julia
jl = Julia(compiled_modules=False,runtime="/home/aistudio/opt/julia-1.8.5/bin/julia")
# import julia
from julia import Main |
测试文件test.ipynb只是一个测试功能是否正常的文件,本地测试使用,可以不用提交到PR里。 |
可以将上述路径配置写到hydra yaml文件里,然后通过cfg.JULIA.xxx访问即可。然后在代码里加几行检测路径是否存在的代码,提醒用户正确设置julia路径。此外在案例文档中,可以加上Julia配置这一章节,里面包含了对配置文件的julia路径设置 |
# def train_mse_func( | ||
# output_dict: Dict[str, "paddle.Tensor"], label_dict: Dict[str, "pgl.Graph"], *args | ||
# ) -> paddle.Tensor: | ||
# return F.mse_loss(output_dict["pred"], label_dict["label"].y) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
一些无用的注释可以删除
examples/RegAE/train_new.py
Outdated
latent_dim, hidden_dim = 100, 100 | ||
# set model | ||
model = ppsci.arch.AutoEncoder( | ||
input_dim=10000, | ||
latent_dim=latent_dim, | ||
hidden_dim=hidden_dim, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AutoEncoder的初始化超参数移动至RegAE.yaml配置文件里
examples/RegAE/train_new.py
Outdated
# pretrained_model_path="./output_AMGNet/checkpoints/latest" | ||
) | ||
# train model | ||
solver.train() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件末尾换行
class AutoEncoder(base.Arch): | ||
def __init__(self, input_dim, latent_dim, hidden_dim): | ||
super(AutoEncoder, self).__init__() | ||
|
||
# encoder | ||
self._encoder_linear = nn.Sequential( | ||
nn.Linear(input_dim, hidden_dim), | ||
nn.Tanh(), | ||
) | ||
self._encoder_mu = nn.Linear(hidden_dim, latent_dim) | ||
self._encoder_log_sigma = nn.Linear(hidden_dim, latent_dim) | ||
|
||
self._decoder = nn.Sequential( | ||
nn.Linear(latent_dim, hidden_dim), | ||
nn.Tanh(), | ||
nn.Linear(hidden_dim, input_dim), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
base.Arch的子类需要做成符号化的形式,即需要具有 input_keys、output_keys两个属性
def __getitem__(self, idx): | ||
if self.data_type == "train": | ||
return self.train_data[idx] | ||
else: | ||
return self.test_data[idx] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dataset需要返回input、label、weight三个数据字典
# 计算mu,log_sigma与 N(0,1)分布的差距 | ||
base = paddle.exp(2. * log_sigma) + paddle.pow(mu, 2) - 1. - 2. * log_sigma | ||
loss = 0.5 * paddle.sum(base) / mu.shape[0] | ||
return loss |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件末尾换行
def __init__(self): | ||
super().__init__(None, None) | ||
|
||
def forward(self, mu, log_sigma): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
所有loss子类的forward写法统一,接受output_dict, label_dict, weight_dict三个变量,然后计算loss并以一个标量Tensor的形式返回。可以将N(0,1)的数据作为标签数据,从label_dict传入
closed for #660 |
PR types
New features
PR changes
Others
Describe
目前约完成 1/2 还在精度对齐ing