-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
86 lines (74 loc) · 2.52 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import sys
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
from random import shuffle
import torch
import os
import config
# from tensorboardX import SummaryWriter
from SONIC.sonic_model import SONICModel
from dataloader.sonardata import SonarData, SonarDataLoader
from utils import cycle
from torch.utils.data import DataLoader
import wandb
from tqdm import tqdm
def train_sonardata(args):
    """Train a SONIC model on sonar data, validating at every checkpoint.

    Writes every attribute of ``args`` to ``<args.outdir>/<args.exp_name>/args.txt``,
    builds train and val loaders wrapped in ``cycle``, then runs ``args.n_iters``
    optimisation steps numbered from ``model.start_step + 1``. Every
    ``args.save_interval`` steps the model is saved and ``validate_model`` runs
    ``args.n_val_iters`` validation batches.

    Args:
        args: Configuration namespace (from ``config.get_args()``). Read here:
            ``outdir``, ``exp_name``, ``n_iters``, ``save_interval``,
            ``n_val_iters``. ``args.phase`` is set to ``"val"`` while the val
            loader is built and is always set to ``"train"`` afterwards.
    """
    out_folder = os.path.join(args.outdir, args.exp_name)
    os.makedirs(out_folder, exist_ok=True)
    args_path = os.path.join(out_folder, 'args.txt')
    with open(args_path, 'w', encoding='utf-8') as file:
        for name in vars(args):
            file.write('{} = {}\n'.format(name, getattr(args, name)))

    # sonardata data loaders; next() is called on these indefinitely below,
    # so cycle() presumably repeats the loader forever (TODO confirm in utils).
    train_loader = SonarDataLoader(args).load_data()
    train_loader_iterator = iter(cycle(train_loader))
    args.phase = "val"
    try:
        val_loader = SonarDataLoader(args).load_data()
    finally:
        # Leave args in the train phase even if building the val loader fails.
        args.phase = "train"
    val_loader_iterator = iter(cycle(val_loader))

    model = SONICModel(args)
    start_step = model.start_step
    # Step offset for validation summaries; advances by n_val_iters per run.
    val_log_i = 0

    # The context manager closes the progress bar even if training raises.
    with tqdm(total=args.n_iters) as pbar:
        for step in range(start_step + 1, start_step + args.n_iters + 1):
            # validate_model switches back to train mode; re-check defensively.
            if not model.model.training:
                model.model.train()
            data = next(train_loader_iterator)
            model.set_input(data)
            model.optimize_parameters()
            model.write_summary(step, pbar)
            if step % args.save_interval == 0 and step > 0:
                model.save_model(step)
                # A plain torch.nn.Module has no `.device`; validate_model does
                # not use this argument, so fall back to None instead of failing.
                validate_model(
                    model,
                    args,
                    val_log_i,
                    device=getattr(model.model, "device", None),
                    val_loader_iterator=val_loader_iterator,
                )
                val_log_i += args.n_val_iters
def validate_model(model, args, val_log_i, device, val_loader_iterator, criterion=None, decoder=None):
    """Run ``args.n_val_iters`` validation batches and log a summary for each.

    The wrapped network is put in eval mode for the duration. It is always
    switched back to train mode afterwards, even if a batch raises.

    Args:
        model: Wrapper exposing ``.model`` (the network), ``set_input``,
            ``val_forward`` and ``write_summary``.
        args: Namespace providing ``n_val_iters``.
        val_log_i: Step offset. Summaries are written at steps
            ``val_log_i + 1`` through ``val_log_i + args.n_val_iters``.
        device: Unused; kept for interface compatibility.
        val_loader_iterator: Iterator of validation batches. ``next()`` is
            called once per step, so a finite iterator that runs out raises
            ``StopIteration``.
        criterion: Unused; kept for interface compatibility.
        decoder: Unused; kept for interface compatibility.
    """
    model.model.eval()
    try:
        for step in range(val_log_i + 1, val_log_i + args.n_val_iters + 1):
            data = next(val_loader_iterator)
            # No autograd graph is needed for input setup and the val forward pass.
            with torch.no_grad():
                model.set_input(data)
                model.val_forward()
            model.write_summary(step, val_set=True)
    finally:
        # Always restore train mode so the caller's training loop resumes correctly.
        model.model.train()
if __name__ == '__main__':
    # Script entry point: parse the configuration, open a W&B run, then train.
    args = config.get_args()
    # To log to W&B: authenticate with wandb.login(key=...) and change `mode`.
    # To continue an earlier run, also pass resume="must" and id=<run id>.
    wandb_kwargs = dict(
        name=args.exp_name,
        reinit=True,
        mode="disabled",  # W&B is switched off for this script as written
        config=vars(args),
    )
    run = wandb.init(**wandb_kwargs)
    train_sonardata(args)