-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
105 lines (101 loc) · 3.79 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
class Config:
    """Bag of run settings: hyper-parameters, buffer sizes and file paths.

    The constructor gives every setting a neutral placeholder value and copies
    the run-specific values (paths, vocabulary limits, sparsity options) from
    the parsed command-line ``args``.  Two static factories then layer a preset
    on top of that: ``get_default_config`` for regular runs and
    ``get_debug_config`` for a configuration with tiny sizes (batch of 7,
    5 contexts, embeddings of 19).
    """

    def __init__(self, args):
        """Initialise every setting and copy run-specific values from *args*.

        :param args: namespace-like object.  The attributes read are
            ``data_path``, ``test_path``, ``save_path_prefix``, ``load_path``,
            ``release``, ``subtoken_words``, ``nodes_words``, ``lasso``,
            ``grouplasso``, ``threshold``, ``sparse_nodes`` and
            ``sparse_subtoken``.
        :raises AttributeError: if *args* lacks any of those attributes.
        """
        # Training loop and input-pipeline placeholders (filled by a preset).
        self.NUM_EPOCHS = 0
        self.SAVE_EVERY_EPOCHS = 0
        self.PATIENCE = 0
        self.BATCH_SIZE = 0
        self.TEST_BATCH_SIZE = 0
        self.READER_NUM_PARALLEL_BATCHES = 0
        self.SHUFFLE_BUFFER_SIZE = 0
        self.CSV_BUFFER_SIZE = None

        # Data locations; a missing test path is normalised to ''.
        self.TRAIN_PATH = args.data_path
        self.TEST_PATH = '' if args.test_path is None else args.test_path

        # Context / vocabulary / layer-size placeholders.
        self.DATA_NUM_CONTEXTS = 0
        self.MAX_CONTEXTS = 0
        self.TARGET_VOCAB_MAX_SIZE = 0
        self.EMBEDDINGS_SIZE = 0
        self.RNN_SIZE = 0
        self.DECODER_SIZE = 0
        self.NUM_DECODER_LAYERS = 0

        # Checkpoint locations.
        self.SAVE_PATH = args.save_path_prefix
        self.LOAD_PATH = args.load_path

        # Sequence-length and regularisation placeholders.
        self.MAX_PATH_LENGTH = 0
        self.MAX_NAME_PARTS = 0
        self.MAX_TARGET_PARTS = 0
        self.EMBEDDINGS_DROPOUT_KEEP_PROB = 0
        self.RNN_DROPOUT_KEEP_PROB = 0

        # Boolean switches.  BEAM_WIDTH starts at 1 here, but both presets
        # below reset it to 0.
        self.BIRNN = False
        self.RANDOM_CONTEXTS = True
        self.BEAM_WIDTH = 1
        self.USE_MOMENTUM = True

        # Values taken verbatim from the command line.
        self.RELEASE = args.release
        self.SUBTOKENS_VOCAB_MAX_SIZE = args.subtoken_words
        self.NODES_VOCAB_MAX_SIZE = args.nodes_words
        self.LASSO = args.lasso
        self.GROUP_LASSO = args.grouplasso
        self.THRESHOLD = args.threshold
        self.SPARSE_NODES = args.sparse_nodes
        self.SPARSE_SUBTOKEN = args.sparse_subtoken

    @staticmethod
    def _build(args, **overrides):
        """Create ``Config(args)`` and then assign *overrides* in keyword order."""
        config = Config(args)
        for attr_name, attr_value in overrides.items():
            setattr(config, attr_name, attr_value)
        return config

    @staticmethod
    def get_default_config(args):
        """Return a Config built from *args* with the regular-run preset applied.

        The subtoken vocabulary limit is left as given in ``args.subtoken_words``.
        """
        return Config._build(
            args,
            NUM_EPOCHS=150,  # the original author also noted 3000 here
            SAVE_EVERY_EPOCHS=5,
            PATIENCE=150,  # the original author also noted 10 here
            BATCH_SIZE=512,
            TEST_BATCH_SIZE=256,
            READER_NUM_PARALLEL_BATCHES=1,
            SHUFFLE_BUFFER_SIZE=10000,
            CSV_BUFFER_SIZE=100 * 1024 * 1024,  # 100 MB
            MAX_CONTEXTS=200,
            TARGET_VOCAB_MAX_SIZE=27000,
            EMBEDDINGS_SIZE=128,
            RNN_SIZE=128 * 2,  # two LSTMs embed the paths, 128 units each
            DECODER_SIZE=320,
            NUM_DECODER_LAYERS=1,
            MAX_PATH_LENGTH=8 + 1,
            MAX_NAME_PARTS=5,
            MAX_TARGET_PARTS=6,
            EMBEDDINGS_DROPOUT_KEEP_PROB=0.75,
            RNN_DROPOUT_KEEP_PROB=0.5,
            BIRNN=True,
            RANDOM_CONTEXTS=True,
            BEAM_WIDTH=0,
            USE_MOMENTUM=True,
        )

    def take_model_hyperparams_from(self, otherConfig):
        """Copy layer-size settings from *otherConfig* into this config.

        EMBEDDINGS_SIZE, RNN_SIZE, DECODER_SIZE, NUM_DECODER_LAYERS and BIRNN
        are always copied.  DATA_NUM_CONTEXTS is copied only when this
        config's own value is not positive.  (Presumably used when restoring
        a saved model; confirm against callers.)
        """
        copied_names = (
            'EMBEDDINGS_SIZE',
            'RNN_SIZE',
            'DECODER_SIZE',
            'NUM_DECODER_LAYERS',
            'BIRNN',
        )
        for attr_name in copied_names:
            setattr(self, attr_name, getattr(otherConfig, attr_name))
        # Keep an explicitly configured context count; otherwise inherit it.
        if self.DATA_NUM_CONTEXTS <= 0:
            self.DATA_NUM_CONTEXTS = otherConfig.DATA_NUM_CONTEXTS

    @staticmethod
    def get_debug_config(args):
        """Return a Config built from *args* with a tiny debugging preset.

        Unlike the default preset, this one also overrides the subtoken
        vocabulary limit (190000), ignoring ``args.subtoken_words``.
        """
        return Config._build(
            args,
            NUM_EPOCHS=3000,
            SAVE_EVERY_EPOCHS=100,
            PATIENCE=200,
            BATCH_SIZE=7,
            TEST_BATCH_SIZE=7,
            READER_NUM_PARALLEL_BATCHES=1,
            SHUFFLE_BUFFER_SIZE=10,
            CSV_BUFFER_SIZE=None,
            MAX_CONTEXTS=5,
            SUBTOKENS_VOCAB_MAX_SIZE=190000,
            TARGET_VOCAB_MAX_SIZE=27000,
            EMBEDDINGS_SIZE=19,
            RNN_SIZE=10,
            DECODER_SIZE=11,
            NUM_DECODER_LAYERS=1,
            MAX_PATH_LENGTH=8 + 1,
            MAX_NAME_PARTS=5,
            MAX_TARGET_PARTS=6,
            EMBEDDINGS_DROPOUT_KEEP_PROB=1,
            RNN_DROPOUT_KEEP_PROB=1,
            BIRNN=True,
            RANDOM_CONTEXTS=True,
            BEAM_WIDTH=0,
            USE_MOMENTUM=False,
        )