-
Notifications
You must be signed in to change notification settings - Fork 0
/
grid_DROPOUT.py
69 lines (57 loc) · 2.28 KB
/
grid_DROPOUT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
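"""
Grid search over dropout (originally written for the CNN model).

Trains a single model per invocation: the dropout probability given on the
command line is applied to both the encoder and decoder dropout
hyper-parameters, and the model type is chosen with --MODEL.
"""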
import logging
import importlib
import time
import argparse
import libs.common.utils as utils
from libs import ModelManager as mm
from config.constants import PathKey, HyperParamKey
# Parse command-line parameters.
# NOTE: every option string must be unique. The dropout option previously
# reused "-d" (already taken by --DATA), which argparse rejects inside
# add_argument() with "conflicting option string: -d", so the script could
# never start. "-d" stays with --DATA (declared first); --DROPOUT now uses "-p".
parser = argparse.ArgumentParser(description="NLP Team Project - Machine Translation - Grid search")
parser.add_argument('-d', '--DATA', dest='data_path',
                    help='path of data files')
parser.add_argument('-s', '--MSAVE', dest='model_save',
                    help='path of model checkpoints')
parser.add_argument('-c', '--CONFIG', dest='config_file',
                    help='config file name (contains basic parameters, normally not tuning here)')
parser.add_argument('-m', '--MODEL', dest='model_type', required=True,
                    help='type of model that you will tune')
parser.add_argument('-p', '--DROPOUT', dest='drop_prob', required=True, type=float,
                    help='dropout probability applied to both encoder and decoder (0.0-1.0)')
args = parser.parse_args()
# Reject out-of-range (or NaN) probabilities up front with a usage message,
# instead of failing later after data loading / model construction.
if not 0.0 <= args.drop_prob <= 1.0:
    parser.error('--DROPOUT must be between 0.0 and 1.0, got {}'.format(args.drop_prob))
# Build the override dictionaries handed to ModelManager.
# An optional config module `config.<name>` (selected with --CONFIG) supplies
# CONFIG, which is merged into config_new, and HPARAM, which becomes hparam_new.
# Path flags given on the command line are applied afterwards, so they take
# precedence over the module's values for those keys.
config_new = {}
hparam_new = None
if args.config_file:
    user_conf = importlib.import_module('config.{}'.format(args.config_file))
    config_new.update(user_conf.CONFIG)
    hparam_new = user_conf.HPARAM
# Only set a path when the corresponding flag was actually given (non-empty).
for path_key, path_value in ((PathKey.DATA_PATH, args.data_path),
                             (PathKey.MODEL_SAVES, args.model_save)):
    if path_value:
        config_new[path_key] = path_value
# Logging setup.
# `ts` is a run timestamp (month-day-hour:minute:second). The active code does
# not consume it; it was intended for a grid-search output file name.
ts = time.strftime("%m-%d-%H:%M:%S", time.localtime())
# No log file path is given; how output is routed is up to utils.init_logger.
utils.init_logger(logfile=None)
logger = logging.getLogger(name='__main__')
########################################
# Train one model at the given dropout #
########################################
# Hyper-parameters from the config module (if any) are given to the manager at
# construction; the command-line dropout value is layered on top of them.
mgr = mm.ModelManager(hparams=hparam_new, control_overrides=config_new)
drop_prob = float(args.drop_prob)  # cast in case drop_prob arrives as a string
dropout_hparams = {
    HyperParamKey.ENC_DROPOUT: drop_prob,
    HyperParamKey.DEC_DROPOUT: drop_prob,
}
label = utils.hparam_to_label(prefix=args.model_type, hparam_dict=dropout_hparams)
mgr.hparams.update(dropout_hparams)

# Load data through the IWSLT loader, build the requested model under `label`,
# train it and plot its training curves.
mgr.load_data(mm.loaderRegister.IWSLT)
mgr.new_model(args.model_type, label=label)
mgr.train()
mgr.graph_training_curves()

report = "Single model train complete.\nModel {} {} training report:\n{}\n===\n===\n===".format(
    args.model_type, label, mgr.model.output_dict)
logger.info(report)