Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Example] Add VAE example #573

Closed
wants to merge 4 commits into from

Conversation

DrRyanHuang
Copy link
Member

@DrRyanHuang DrRyanHuang commented Oct 18, 2023

PR types

New features

PR changes

Others

Describe

目前约完成 1/2 还在精度对齐ing

@paddle-bot
Copy link

paddle-bot bot commented Oct 18, 2023

Thanks for your contribution!

@paddle-bot
Copy link

paddle-bot bot commented Oct 18, 2023

✅ This PR's description meets the template requirements!
Please wait for other CI results.

@DrRyanHuang DrRyanHuang changed the title add vae [Example] Add VAE example Oct 18, 2023
@HydrogenSulfate
Copy link
Collaborator

感谢大佬提PR,可以合并一下最新develop分支的代码,然后案例可以用hydra改造一下,参考这个案例:https://github.com/PaddlePaddle/PaddleScience/tree/develop/examples/bracket

@luotao1 luotao1 added the HappyOpenSource 快乐开源活动issue与PR label Oct 20, 2023
@DrRyanHuang
Copy link
Member Author

@HydrogenSulfate 不好意思,打扰了,AISTUDIO原文用到了 Julia 环境,我这边写test的时候也配置配置一个 Julia 环境,还是怎么操作一下嘞? 😆

原文地址 :https://aistudio.baidu.com/projectdetail/5541961

# julia 依赖
os.environ['JULIA_DEPOT_PATH'] = '/home/aistudio/opt/julia_package'
# pip 依赖
sys.path.append('/home/aistudio/opt/external-libraries')

# julieries
from julia.api import Julia
jl = Julia(compiled_modules=False,runtime="/home/aistudio/opt/julia-1.8.5/bin/julia")
# import julia
from julia import Main

@HydrogenSulfate
Copy link
Collaborator

@HydrogenSulfate 不好意思,打扰了,AISTUDIO原文用到了 Julia 环境,我这边写test的时候也配置配置一个 Julia 环境,还是怎么操作一下嘞? 😆

原文地址 :https://aistudio.baidu.com/projectdetail/5541961

# julia 依赖
os.environ['JULIA_DEPOT_PATH'] = '/home/aistudio/opt/julia_package'
# pip 依赖
sys.path.append('/home/aistudio/opt/external-libraries')

# julieries
from julia.api import Julia
jl = Julia(compiled_modules=False,runtime="/home/aistudio/opt/julia-1.8.5/bin/julia")
# import julia
from julia import Main

@HydrogenSulfate 不好意思,打扰了,AISTUDIO原文用到了 Julia 环境,我这边写test的时候也配置配置一个 Julia 环境,还是怎么操作一下嘞? 😆

原文地址 :https://aistudio.baidu.com/projectdetail/5541961

# julia 依赖
os.environ['JULIA_DEPOT_PATH'] = '/home/aistudio/opt/julia_package'
# pip 依赖
sys.path.append('/home/aistudio/opt/external-libraries')

# julieries
from julia.api import Julia
jl = Julia(compiled_modules=False,runtime="/home/aistudio/opt/julia-1.8.5/bin/julia")
# import julia
from julia import Main

测试文件test.ipynb只是一个测试功能是否正常的文件,本地测试使用,可以不用提交到PR里。

@HydrogenSulfate
Copy link
Collaborator

HydrogenSulfate commented Oct 25, 2023

@HydrogenSulfate 不好意思,打扰了,AISTUDIO原文用到了 Julia 环境,我这边写test的时候也配置配置一个 Julia 环境,还是怎么操作一下嘞? 😆

原文地址 :https://aistudio.baidu.com/projectdetail/5541961

# julia 依赖
os.environ['JULIA_DEPOT_PATH'] = '/home/aistudio/opt/julia_package'
# pip 依赖
sys.path.append('/home/aistudio/opt/external-libraries')

# julieries
from julia.api import Julia
jl = Julia(compiled_modules=False,runtime="/home/aistudio/opt/julia-1.8.5/bin/julia")
# import julia
from julia import Main

可以将上述路径配置写到hydra yaml文件里,然后通过cfg.JULIA.xxx访问即可。然后在代码里加几行检测路径是否存在的代码,提醒用户正确设置julia路径。此外在案例文档中,可以加上Julia配置这一章节,里面包含了对配置文件的julia路径设置

examples/RegAE/train_new.py Outdated Show resolved Hide resolved
Comment on lines +37 to +40
# def train_mse_func(
# output_dict: Dict[str, "paddle.Tensor"], label_dict: Dict[str, "pgl.Graph"], *args
# ) -> paddle.Tensor:
# return F.mse_loss(output_dict["pred"], label_dict["label"].y)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

一些无用的注释可以删除

Comment on lines 71 to 77
latent_dim, hidden_dim = 100, 100
# set model
model = ppsci.arch.AutoEncoder(
input_dim=10000,
latent_dim=latent_dim,
hidden_dim=hidden_dim,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AutoEncoder的初始化超参数移动至RegAE.yaml配置文件里

# pretrained_model_path="./output_AMGNet/checkpoints/latest"
)
# train model
solver.train()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

文件末尾换行

Comment on lines +30 to +46
class AutoEncoder(base.Arch):
def __init__(self, input_dim, latent_dim, hidden_dim):
super(AutoEncoder, self).__init__()

# encoder
self._encoder_linear = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.Tanh(),
)
self._encoder_mu = nn.Linear(hidden_dim, latent_dim)
self._encoder_log_sigma = nn.Linear(hidden_dim, latent_dim)

self._decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, input_dim),
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

base.Arch的子类需要做成符号化的形式,即需要具有 input_keys、output_keys两个属性

ppsci/data/dataset/npz_dataset.py Outdated Show resolved Hide resolved
ppsci/data/dataset/npz_dataset.py Outdated Show resolved Hide resolved
Comment on lines +334 to +338
def __getitem__(self, idx):
if self.data_type == "train":
return self.train_data[idx]
else:
return self.test_data[idx]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dataset需要返回input、label、weight三个数据字典

# 计算mu,log_sigma与 N(0,1)分布的差距
base = paddle.exp(2. * log_sigma) + paddle.pow(mu, 2) - 1. - 2. * log_sigma
loss = 0.5 * paddle.sum(base) / mu.shape[0]
return loss
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

文件末尾换行

def __init__(self):
super().__init__(None, None)

def forward(self, mu, log_sigma):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

所有loss子类的forward写法统一,接受output_dict, label_dict, weight_dict三个变量,然后计算loss并以一个标量Tensor的形式返回。可以将N(0,1)的数据作为标签数据,从label_dict传入

@HydrogenSulfate
Copy link
Collaborator

closed for #660

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor HappyOpenSource 快乐开源活动issue与PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants