-
Notifications
You must be signed in to change notification settings - Fork 3
/
master.py
76 lines (56 loc) · 2.39 KB
/
master.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# -*- coding: utf-8 -*-
# --------------------------------------------------
#
# master.py
#
# Main program that runs the validation and test phases of the
# Attribute Label Embedding (ALE) algorithm on the aPY dataset
#
# Written by cetinsamet -*- cetin.samet@metu.edu.tr
# April, 2019
# --------------------------------------------------
import argparse
import os
import pickle
import subprocess
import sys
from decimal import Decimal

import numpy as np
from easydict import EasyDict as edict

from config import MAIN_DATAPATH, VAL_DATAPATH, TEST_DATAPATH, OBJPATH
def prepareData():
    """Build the experiment configuration consumed by the ALE phase scripts.

    Returns:
        An EasyDict with the optimisation hyper-parameters and the paths of
        every ``.mat`` data file:

        * ``LR`` (1e-2), ``BATCH_SIZE`` (64), ``N_EPOCH`` (200) and
          ``INFO_EPOCH`` (1; presumably the logging interval in epochs --
          confirm in validation.py / test.py);
        * ``ALL_CLASS_VEC`` under ``MAIN_DATAPATH``;
        * ``TRAIN_*``, ``VAL_SEEN_*`` and ``VAL_UNSEEN_*`` under
          ``VAL_DATAPATH`` (validation phase);
        * ``TRAINVAL_*``, ``TEST_SEEN_*`` and ``TEST_UNSEEN_*`` under
          ``TEST_DATAPATH`` (test phase).

    Paths are built with ``os.path.join`` rather than ``+`` so that a data
    directory configured without a trailing separator still produces valid
    file paths; when the directory already ends with a separator the result
    is identical to plain string concatenation.
    """
    cfg = edict()

    # Optimisation hyper-parameters.
    cfg.LR = 1e-2
    cfg.BATCH_SIZE = 64
    cfg.N_EPOCH = 200
    cfg.INFO_EPOCH = 1

    # Data shared by both phases.
    cfg.ALL_CLASS_VEC = os.path.join(MAIN_DATAPATH, 'all_class_vec.mat')

    # Validation-phase data.
    cfg.TRAIN_FEATURES = os.path.join(VAL_DATAPATH, 'train_features.mat')
    cfg.TRAIN_LABELS = os.path.join(VAL_DATAPATH, 'train_labels.mat')
    cfg.VAL_SEEN_FEATURES = os.path.join(VAL_DATAPATH, 'val_seen_features.mat')
    cfg.VAL_SEEN_LABELS = os.path.join(VAL_DATAPATH, 'val_seen_labels.mat')
    cfg.VAL_UNSEEN_FEATURES = os.path.join(VAL_DATAPATH, 'val_unseen_features.mat')
    cfg.VAL_UNSEEN_LABELS = os.path.join(VAL_DATAPATH, 'val_unseen_labels.mat')

    # Test-phase data.
    cfg.TRAINVAL_FEATURES = os.path.join(TEST_DATAPATH, 'trainval_features.mat')
    cfg.TRAINVAL_LABELS = os.path.join(TEST_DATAPATH, 'trainval_labels.mat')
    cfg.TEST_SEEN_FEATURES = os.path.join(TEST_DATAPATH, 'test_seen_features.mat')
    cfg.TEST_SEEN_LABELS = os.path.join(TEST_DATAPATH, 'test_seen_labels.mat')
    cfg.TEST_UNSEEN_FEATURES = os.path.join(TEST_DATAPATH, 'test_unseen_features.mat')
    cfg.TEST_UNSEEN_LABELS = os.path.join(TEST_DATAPATH, 'test_unseen_labels.mat')

    return cfg
if __name__ == '__main__':
    # Maps each --mode choice to the script that implements that phase.
    # Script names are resolved relative to the current working directory,
    # exactly as before.
    phase_scripts = {
        'validation': 'validation.py',  # training on the VALIDATION set
        'test': 'test.py',              # training on the TEST set
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--mode",
                        choices=["validation", "test"],
                        default="validation",
                        help="select training phase (validation or test)")
    args = parser.parse_args()

    # Build the configuration and pickle it to OBJPATH. NOTE(review): the
    # phase script runs in a separate process, so presumably it reloads the
    # configuration from OBJPATH -- confirm in validation.py / test.py.
    cfg = prepareData()
    with open(OBJPATH, 'wb') as outfile:
        pickle.dump(cfg, outfile, pickle.HIGHEST_PROTOCOL)

    # argparse's `choices` guarantees args.mode is a key of phase_scripts,
    # so no fallback branch is needed. An argv list (no shell=True) avoids
    # spawning a shell and any quoting/injection concerns.
    returncode = subprocess.call(['python3', phase_scripts[args.mode]])

    # Propagate the child's exit status so a failed phase makes master.py
    # fail too, instead of always exiting with status 0.
    sys.exit(returncode)