-
Notifications
You must be signed in to change notification settings - Fork 11
/
train_mod.py
151 lines (126 loc) · 5.1 KB
/
train_mod.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
train_mod.py: Training of the channel model
This program trains both the link state predictor
and path VAE models from the ray tracing data.
"""
import numpy as np
import matplotlib.pyplot as plt
import pickle
import tensorflow as tf
tfk = tf.keras
tfkm = tf.keras.models
tfkl = tf.keras.layers
import tensorflow.keras.backend as K
import argparse
from models import ChanMod
"""
Parse arguments from command line
"""
parser = argparse.ArgumentParser(description='Trains the channel model')
parser.add_argument('--nlatent',action='store',default=20,type=int,\
help='number of latent variables')
parser.add_argument('--npaths_max',action='store',default=20,type=int,\
help='max number of paths per link')
parser.add_argument('--nepochs_link',action='store',default=50,type=int,\
help='number of epochs for training the link model')
parser.add_argument('--lr_link',action='store',default=1e-3,type=float,\
help='learning rate for the link model')
parser.add_argument('--nepochs_path',action='store',default=2000,type=int,\
help='number of epochs for training the path model')
parser.add_argument('--lr_path',action='store',default=1e-4,type=float,\
help='learning rate for the path model')
parser.add_argument('--out_var_min',action='store',default=1e-4,type=float,\
help='min variance in the decoder outputs. Used for conditioning')
parser.add_argument('--init_stddev',action='store',default=10.0,type=float,\
help='weight and bias initialization')
parser.add_argument('--nunits_enc',action='store',nargs='+',\
default=[200,80],type=int,\
help='num hidden units for the encoder')
parser.add_argument('--nunits_dec',action='store',nargs='+',\
default=[80,200],type=int,\
help='num hidden units for the decoder')
parser.add_argument('--nunits_link',action='store',nargs='+',\
default=[50,25],type=int,\
help='num hidden units for the link state predictor')
parser.add_argument('--model_dir',action='store',\
default='model_data', help='directory to store models')
parser.add_argument('--no_fit_link', dest='no_fit_link', action='store_true',\
help="Does not fit the link model")
parser.add_argument('--no_fit_path', dest='no_fit_path', action='store_true',\
help="Does not fit the path model")
parser.add_argument('--checkpoint_period',action='store',default=100,type=int,\
help='Period in epochs for storing checkpoint. A value of 0 indicates no checkpoints')
parser.add_argument('--batch_ind',action='store',default=-1,type=int,\
help='batch index. -1 indicates no batch index')
# Parse the command line and unpack the options into the module-level
# names used by the rest of the script.
args = parser.parse_args()

# Model dimensions and hidden layer sizes
nlatent, npaths_max = args.nlatent, args.npaths_max
nunits_enc, nunits_dec, nunits_link = (
    args.nunits_enc, args.nunits_dec, args.nunits_link)

# Optimizer settings: (epochs, learning rate) per model
nepochs_link, lr_link = args.nepochs_link, args.lr_link
nepochs_path, lr_path = args.nepochs_path, args.lr_path

# Initialization and conditioning values
init_stddev, out_var_min = args.init_stddev, args.out_var_min

# Output location, batch selection and checkpointing
model_dir, batch_ind = args.model_dir, args.batch_ind
checkpoint_period = args.checkpoint_period

# The command line exposes "no_fit_*" switches; the script works with the
# positive form.
fit_link, fit_path = not args.no_fit_link, not args.no_fit_path
# Overwrite parameters based on batch index.
# This is used in HPC training with multiple batches: a non-negative
# batch_ind selects one row of the parallel tables below (latent dimension,
# encoder units, decoder units) and the matching output directory suffix.
# All four tables must have the same length.
#lr_batch = [1e-3,1e-3,1e-3,1e-4,1e-4,1e-4]
nlatent_batch = [10, 10, 10, 20]
nunits_enc_batch = [[50, 20], [100, 40], [200, 80], [200, 80]]
nunits_dec_batch = [[20, 50], [40, 100], [80, 200], [80, 200]]
dir_suffix = ['nl10_nu50', 'nl10_nu100', 'nl10_nu200', 'nl20_nu200']

if batch_ind >= 0:
    # Fail early with a readable message when the index is beyond the
    # configuration tables.  The exception type is the same one the plain
    # list lookup would raise, only the message is more helpful.
    if batch_ind >= len(dir_suffix):
        raise IndexError(
            'batch_ind=%d is out of range: valid batch indices are 0..%d '
            '(use -1 for no batch index)' % (batch_ind, len(dir_suffix) - 1))

    # NOTE: the output directory is a hard-coded scratch path; it replaces
    # whatever was passed with --model_dir.
    model_dir = ('/scratch/sr663/models_20200726/model_data_%s' % dir_suffix[batch_ind])
    #lr_path = lr_batch[batch_ind]
    nlatent = nlatent_batch[batch_ind]
    nunits_enc = nunits_enc_batch[batch_ind]
    nunits_dec = nunits_dec_batch[batch_ind]

    # Echo the effective configuration so it appears in the job log.
    print('batch_ind=%d' % batch_ind)
    print('model_dir= %s' % model_dir)
    print('nunits_enc=%s' % str(nunits_enc))
    print('nunits_dec=%s' % str(nunits_dec))
    #print('lr=%12.4e' % lr_path)
    print('nlatent=%d' % nlatent)
# Load the data.
# The pickle file holds a 3-item sequence that is unpacked into the training
# set, the test set and pl_max (later passed to ChanMod as pl_max).
# NOTE: unpickling can execute arbitrary code; only load files from a
# trusted source.
fn = 'train_test_data.p'
with open(fn, 'rb') as fp:
    _loaded = pickle.load(fp)
train_data, test_data, pl_max = _loaded
del _loaded
"""
Train the link classifier
"""
K.clear_session()
# Construct the channel model object
chan_mod = ChanMod(nlatent=nlatent,pl_max=pl_max, npaths_max=npaths_max,\
nunits_enc=nunits_enc, nunits_dec=nunits_dec,\
nunits_link=nunits_link,\
out_var_min=out_var_min,\
init_bias_stddev=init_stddev,\
init_kernel_stddev=init_stddev, model_dir=model_dir)
if fit_link:
# Build the link model
chan_mod.build_link_mod()
# Fit the link model
chan_mod.fit_link_mod(train_data, test_data, lr=lr_link,\
epochs=nepochs_link)
# Save the link classifier model
chan_mod.save_link_model()
else:
# Load the link model
chan_mod.load_link_model()
"""
Train the path loss model
"""
if fit_path:
chan_mod.build_path_mod()
chan_mod.fit_path_mod(train_data, test_data, lr=lr_path,\
epochs=nepochs_path,\
checkpoint_period=checkpoint_period)
# Save the path loss model
chan_mod.save_path_model()