-
Notifications
You must be signed in to change notification settings - Fork 0
/
grid_HIDDEN_LR.py
76 lines (63 loc) · 2.59 KB
/
grid_HIDDEN_LR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
Grid search
hidden_size, learning_rate
"""
import logging
import importlib
import time
import argparse
import libs.common.utils as utils
from libs import ModelManager as mm
from config.constants import PathKey, HyperParamKey
# Command-line interface: data/model paths, base config, and the two
# grid-searched hyperparameters (hidden size, learning rate).
parser = argparse.ArgumentParser(description="NLP Team Project - Machine Translation - Grid search")
_cli_args = [
    (('-d', '--DATA'), dict(dest='data_path',
                            help='path of data files')),
    (('-s', '--MSAVE'), dict(dest='model_save',
                             help='path of model checkpoints')),
    (('-c', '--CONFIG'), dict(dest='config_file',
                              help='config file name (contains basic parameters, normally not tuning here)')),
    (('-m', '--MODEL'), dict(dest='model_type', required=True,
                             help='type of model that you will tune')),
    (('-v', '--HIDDEN'), dict(dest='hidden_size', required=True)),
    (('-l', '--LR'), dict(dest='learning_rate')),
]
for _flags, _kwargs in _cli_args:
    parser.add_argument(*_flags, **_kwargs)
args = parser.parse_args()

# Assemble config/hyperparameter overrides from the CLI flags.
# CLI values take precedence over anything loaded from the config module.
config_new = {}
hparam_new = {}
if args.config_file:
    # -c basic -> imports config.basic, which must expose CONFIG and HPARAM dicts.
    user_conf = importlib.import_module('config.{}'.format(args.config_file))
    config_new.update(user_conf.CONFIG)
    # Copy: the key assignments below must not mutate the config module's
    # module-level HPARAM dict in place.
    hparam_new = dict(user_conf.HPARAM)
if args.data_path:
    config_new.update({PathKey.DATA_PATH: args.data_path})
if args.model_save:
    config_new.update({PathKey.MODEL_SAVES: args.model_save})
if args.learning_rate:
    # An explicit LR is applied to both the encoder and the decoder.
    hparam_new[HyperParamKey.ENC_LR] = float(args.learning_rate)
    hparam_new[HyperParamKey.DEC_LR] = float(args.learning_rate)
else:
    # Default heuristic: learning rate scales inversely with hidden size.
    hparam_new[HyperParamKey.DEC_LR] = 1 / float(args.hidden_size)
    hparam_new[HyperParamKey.ENC_LR] = 1 / float(args.hidden_size)
# Initialize project-wide logging (console only -- no logfile).
utils.init_logger(logfile=None)
logger = logging.getLogger('__main__')
# Timestamp for the (currently disabled) per-run results filename.
ts = time.strftime("%m-%d-%H:%M:%S")
# output_fn = '{}gridSearch-{}{}'.format(config_new[PathKey.MODEL_SAVES], args.model_type, ts)
########################
# Hyper-parameter Lists #
########################
# tune
# Build the manager with the base hparams plus path overrides from the CLI.
mgr = mm.ModelManager(hparams=hparam_new, control_overrides=config_new)
# Fold the grid-searched hidden size into the hparam dict.
hparam_new.update({
    HyperParamKey.HIDDEN_SIZE: int(args.hidden_size),  # remember to cast to right type!
})
# NOTE(review): HIDDEN_SIZE is added only after ModelManager construction and
# then pushed in via mgr.hparams.update() below -- presumably ModelManager
# copies its hparams argument, making the late update necessary; confirm
# against ModelManager before reordering these statements.
# Human-readable run label derived from model type + hyperparameter values.
label = utils.hparam_to_label(prefix=args.model_type, hparam_dict=hparam_new)
mgr.hparams.update(hparam_new)
# Load the IWSLT translation dataset, build the model, and train it.
mgr.load_data(mm.loaderRegister.IWSLT)
mgr.new_model(args.model_type, label=label)
mgr.train()
mgr.graph_training_curves()
# mgr.get_results().to_csv(output_fn + '.csv')
# Final report goes through the logger; mgr.model.output_dict holds the
# training metrics collected by ModelManager.
logger.info("Single model train complete.\nModel {} {} training report:\n{}\n===\n===\n===".format(
    args.model_type, label, mgr.model.output_dict))