-
Notifications
You must be signed in to change notification settings - Fork 0
/
parameters.py
69 lines (56 loc) · 2.02 KB
/
parameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
'''
Defines all parameters and hyperparameters used in training and evaluation.
Defines the menu function, model checkpoint function and history caching function.
'''
import torch
# Compute device: use CUDA when torch reports it is available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Sentence encoder and input sizes -------------------------------------
# Name of the sentence-transformers checkpoint to load.
SBERT_VERSION = "sentence-transformers/paraphrase-mpnet-base-v2"
MAX_SENT_LENGTH = 128  # presumably max tokens per sentence — TODO confirm
MAX_PARA_LENGTH = 8    # presumably max sentences per paragraph — TODO confirm
BATCH_SIZE = 2
EMB_SIZE = 768         # presumably the SBERT embedding width — verify vs. model
N_HIDDEN = 100
N_EPOCH = 20

# --- CNN model --------------------------------------------------------------
CNN_WINDOWS = [2, 3]   # convolution window sizes
CNN_LR = 0.0001

# --- Transformer model ------------------------------------------------------
TRANS_LR = 0.0001
TRANS_N_HIDDEN = 50
N_HEAD = 2
TRANS_LAYER = 2
TRANS_DROPOUT = 0.2

# --- Poly-encoder model -----------------------------------------------------
POLY_M = 16
POLY_LR = 0.0001

# --- Data-loader keyword arguments (one independent dict per split) ---------
TEST_PARAM = dict(batch_size=BATCH_SIZE, shuffle=False)
# NOTE(review): training batches are not shuffled; confirm this is intentional.
TRAIN_PARAM = dict(batch_size=BATCH_SIZE, shuffle=False)
VAL_PARAM = dict(batch_size=BATCH_SIZE, shuffle=False)
def MENU(input_fn=input):
    ''' The menu function.

    Prints the available options, then reads the chosen option and the model
    checkpoint path. The training-history path is only requested when the
    chosen option is not '3' ("Evaluate the last model.").

    @ input_fn (callable): Reads one line of user input; called with the
        prompt string. Defaults to the built-in ``input``. Pass another
        callable to drive the menu non-interactively (e.g. scripts or tests).
    @ return (tuple of str): ``(option, model_dir, hist_dir)``. ``option`` has
        surrounding whitespace removed. ``hist_dir`` is ``''`` when option '3'
        is chosen.
    '''
    print("Please select your option:")
    print("1. Train a new model.")
    print("2. Continue training the last model.")
    print("3. Evaluate the last model.")
    # Strip so an accidental space (e.g. "3 ") still selects the intended option
    # instead of silently falling into the training branch.
    option = input_fn('Your Option: ').strip()
    model_dir = input_fn('Model Directory:')
    # Evaluation needs only the checkpoint; the other options also need a history path.
    hist_dir = '' if option == '3' else input_fn('Training History Directory:')
    return option, model_dir, hist_dir
def SAVE_MODEL(mod, opt, dir, val_loss):
    ''' Model checkpoint function: write model/optimizer state to disk with torch.save.

    The saved object is a dict with the keys 'model_state_dict',
    'optimizer_state_dict' and 'validation_loss'.

    @ mod (model object): The model whose ``state_dict()`` is saved.
    @ opt (optimizer object): The optimizer whose ``state_dict()`` is saved.
    @ dir (str): Despite its name, the file path handed to ``torch.save``.
    @ val_loss (float): The last epoch's validation loss, kept for resuming training.
    '''
    checkpoint = dict(
        model_state_dict=mod.state_dict(),
        optimizer_state_dict=opt.state_dict(),
        validation_loss=val_loss,
    )
    torch.save(checkpoint, dir)
def SAVE_HISTORY(his, dir):
    ''' History cache: persist the loss history with torch.save.

    @ his (dictionary): The dictionary caching the training and validation losses.
    @ dir (str): Despite its name, the file path handed to ``torch.save``.
    '''
    target_path = dir
    torch.save(obj=his, f=target_path)