-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_prior.py
30 lines (24 loc) · 887 Bytes
/
main_prior.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from gradeadreamer.prior import Trainer
import torch
import numpy as np
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every RNG below.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Kept for parity with the original script; safe (silently ignored)
    # when CUDA is not available.
    torch.cuda.manual_seed(seed)


def main():
    """Parse command-line arguments, build the config and train the prior.

    Config precedence (highest first):
      * ``--gpu`` / ``--prompt`` passed explicitly on the command line,
      * ``key=value`` dotlist overrides (any argument argparse does not know),
      * values from the YAML file given by ``--config``,
      * ``gpu_id`` falls back to ``"0"`` if none of the above set it.
    """
    import argparse

    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=False, default="./configs/prior.yaml", help="path to the yaml config file")
    # Default is None (not "0") so an explicit --gpu can be distinguished from
    # "not given"; previously a gpu_id in the config silently overrode --gpu.
    parser.add_argument("--gpu", required=False, default=None)
    parser.add_argument("--prompt", required=True, help="prompt")
    args, extras = parser.parse_known_args()

    # Unrecognised arguments are treated as OmegaConf dotlist overrides
    # (e.g. ``iters=500``) layered on top of the YAML config.
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
    if args.gpu is not None:
        opt.gpu_id = args.gpu
    elif "gpu_id" not in opt:
        opt.gpu_id = "0"
    opt.prompt = args.prompt

    # seed
    seed_everything(opt.seed)

    # train
    trainer = Trainer(opt)
    trainer.train(opt.iters)


if __name__ == "__main__":
    main()