diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b9a78bb
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+# Default ignored files
+/workspace.xml
+logs/
+vqa_experiments/__pycache__
+snapshots
+*__pycache__
\ No newline at end of file
diff --git a/README.md b/README.md
index f39e9dd..7c33706 100644
--- a/README.md
+++ b/README.md
@@ -100,6 +100,16 @@ We save out incremental weights and associated data for REMIND after each evalua
 3. Run `run_imagenet_experiment.sh`
 
 ## Training REMIND on VQA Datasets
+We use the gensen library for question features. Execute the following steps to set it up:
+```
+git clone git@github.com:erobic/gensen.git ${GENSENPATH}
+cd ${GENSENPATH}/data/embedding
+chmod +x glove2h5.sh && ./glove2h5.sh
+cd ${GENSENPATH}/data/models
+chmod +x download_models.sh && ./download_models.sh
+```
+
 ### Training REMIND on CLEVR
 _Note: For convenience, we pre-extract all the features including the PQ encoded features. This requires 140 GB of free space._
 1. Download and extract CLEVR images+annotations:
@@ -130,7 +140,34 @@ _Note: For convenience, we pre-extract all the features including the PQ encoded
    - In `pq_encoding_clevr.py`, change the value of `PATH` and `streaming_type` (as either 'iid' or 'qtype')
    - Train PQ encoder and extract features: `python vqa_experiments/clevr/pq_encoding_clevr.py`
 
+4. Train REMIND
+   - Edit `data_path` in `vqa_experiments/configs/config_CLEVR_streaming.py`
+   - Run `./vqa_experiments/run_clevr_experiment.sh` (Set `DATA_ORDER` to either `qtype` or `iid` to define the data order)
+
 ### Training REMIND on TDIUC
+1. Download TDIUC
+   ```
+   cd ${TDIUC_PATH}
+   wget https://kushalkafle.com/data/TDIUC.zip && unzip TDIUC.zip
+   cd TDIUC && python setup.py --download Y # You may need to change print '' statements to print('')
+   ```
+
+2. Extract question features
+   - Edit `vqa_experiments/tdiuc/extract_question_features_tdiuc.py`, changing `DATA_PATH` to point to the TDIUC dataset and `GENSEN_PATH` to point to the gensen repository, then extract the features:
+     `python vqa_experiments/tdiuc/extract_question_features_tdiuc.py`
+   - Pre-process the TDIUC questions: edit the `PATH` variable in `vqa_experiments/tdiuc/preprocess_tdiuc.py` to point to the directory where TDIUC was extracted, then run:
+     `python vqa_experiments/tdiuc/preprocess_tdiuc.py`
+
+3. Extract image features, train PQ encoder and extract encoded features
+   - Extract image features: `python -u vqa_experiments/tdiuc/extract_image_features_tdiuc.py --path /path/to/TDIUC`
+   - In `pq_encoding_tdiuc.py`, change the value of `PATH` and `streaming_type` (as either 'iid' or 'qtype')
+   - Train PQ encoder and extract features: `python vqa_experiments/tdiuc/pq_encoding_tdiuc.py`
+
+4. Train REMIND
+   - Edit `data_path` in `vqa_experiments/configs/config_TDIUC_streaming.py`
+   - Run `./vqa_experiments/run_tdiuc_experiment.sh` (Set `DATA_ORDER` to either `qtype` or `iid` to define the data order)
+
 ## Citation
 If using this code, please cite our paper.
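A quick way to sanity-check the extraction pipeline described above (a minimal sketch; paths, file names, and dataset keys assume the defaults used by the TDIUC scripts in this diff):

```python
import json
import h5py

PATH = '/hdd/robik/TDIUC'  # same default as the TDIUC scripts below

# Image features: 49 (7x7) ResNet-152 vectors of dimension 2048 per image.
with h5py.File(f'{PATH}/all_tdiuc_resnet.h5', 'r') as f:
    print(f['image_features'].shape)  # expected: (num_images, 49, 2048)
    print(f['iids'][:5])              # COCO image ids, parsed from file names

# The LUT written by extract_image_features_tdiuc.py maps image id -> row index.
lut = json.load(open(f'{PATH}/map_tdiuc_resnet.json'))
some_iid = next(iter(lut['image_id_to_ix']))
print(some_iid, '->', lut['image_id_to_ix'][some_iid])
```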
diff --git a/vqa_experiments/clevr/pq_encoding_clevr.py b/vqa_experiments/clevr/pq_encoding_clevr.py
index 114dc72..ac1e57a 100644
--- a/vqa_experiments/clevr/pq_encoding_clevr.py
+++ b/vqa_experiments/clevr/pq_encoding_clevr.py
@@ -6,11 +6,13 @@ import h5py
 import json
 
-PATH = '/hdd/robik/CLEVR'
+# Change these based on the data set
+PATH = '/hdd/robik/CLEVR'  # Change this
+streaming_type = 'iid'  # Change this ('iid' or 'qtype')
+
 feat_name = f'{PATH}/all_clevr_resnet_largestage3'
 train_filename = f'{PATH}/train_clevr.h5'
 lut_name = f'{PATH}/map_clevr_resnet_largestage3.json'
-streaming_type = 'iid'
 
 feat_dim = 1024
 num_feat_maps = 196
diff --git a/vqa_experiments/clevr/preprocess_clevr.py b/vqa_experiments/clevr/preprocess_clevr.py
index 3991823..0b7b8b4 100644
--- a/vqa_experiments/clevr/preprocess_clevr.py
+++ b/vqa_experiments/clevr/preprocess_clevr.py
@@ -55,7 +55,7 @@ most_common = Counter(meta[m]).most_common()
     lut[f'{m}2idx'] = {a[0]: idx for idx, a in enumerate(most_common)}
 
-json.dump(lut, open('LUT_clevr.json', 'w'))
+json.dump(lut, open(f'{PATH}/LUT_clevr.json', 'w'))
 # %%
 dt = h5py.special_dtype(vlen=str)
 for split in ['train', 'val']:
diff --git a/vqa_experiments/configs/config_CLEVR_streaming.py b/vqa_experiments/configs/config_CLEVR_streaming.py
index d69c990..595ef67 100644
--- a/vqa_experiments/configs/config_CLEVR_streaming.py
+++ b/vqa_experiments/configs/config_CLEVR_streaming.py
@@ -32,8 +32,6 @@ test_on = 'full'  # 'full' or 'valid'
 
 arrangement = dict()
-arrangement['train'] = 'random'  # 'random', 'aidx', 'atypeidx', 'qtypeidx'
-arrangement['val'] = 'random'  # 'random', 'aidx', 'atypeidx', 'qtypeidx'
 
 # How many to train/test on
 # How many of "indices" to train on, E.g., if arrangement is ans_class, it refers
@@ -83,7 +81,7 @@ num_hidden = 1024
 use_model = s_mac.sMacNetwork  # BLAH
 optimizer = torch.optim.Adamax
-lr = 1e-4
+lr = 3e-4
 save_models = False
 if not soft_targets:
     train_on = 'valid'
diff --git a/vqa_experiments/configs/config_TDIUC_streaming.py b/vqa_experiments/configs/config_TDIUC_streaming.py
new file mode 100644
index 0000000..f4da7b0
--- /dev/null
+++ b/vqa_experiments/configs/config_TDIUC_streaming.py
@@ -0,0 +1,81 @@
+"""
+Written by Kushal, modified by Robik
+"""
+import vqa_experiments.vqa_models as vqa_models
+import torch
+from vqa_experiments.dictionary import Dictionary
+import sys
+from vqa_experiments.vqa_models import WordEmbedding
+from vqa_experiments.s_mac import s_mac
+
+# Model and runtime configuration choices. A copy of this will be saved along with the model
+# weights so that it is easy to reproduce later.
+
+# Data
+data_path = '/hdd/robik/TDIUC'
+dataset = 'tdiuc'
+img_feat = 'resnetpq_iid'  # updn, resnet, updnmkii, resnetmkii
+mkii = False  # If you want to also load codebook indices
+data_subset = 1.0
+d = Dictionary.load_from_file(f'vqa_experiments/data/dictionary_{dataset}.pkl')
+
+map_path = f'{data_path}/map_tdiuc_resnet.json'
+
+train_file = f'{data_path}/train_{dataset}.h5'
+val_file = f'{data_path}/val_{dataset}.h5'
+
+train_batch_size = 512
+val_batch_size = 512
+num_classes = 1480  # Number of classifier units: 1480 for TDIUC, 31xx for VQA, 28 for CLEVR
+
+train_on = 'full'
+test_on = 'full'  # 'full' or 'valid'
+
+arrangement = dict()
+
+only_first_k = dict()
+only_first_k['train'] = sys.maxsize  # Use sys.maxsize to load all
+only_first_k['val'] = sys.maxsize  # Use sys.maxsize to load all
+
+qnorm = True  # Normalize ques feat?
+imnorm = True  # Normalize img feat?
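+# NOTE: `arrangement` above is deliberately left empty; vqa_trainer.py now
+# fills in arrangement['train'/'val'] from the --data_order flag
+# ('iid' -> 'random', otherwise 'qtypeidx'), so configs no longer hard-code it.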
+
+shuffle = False
+
+fetch_all = False
+
+if fetch_all:  # For ques_type, ans_class or ans_type arrangement, get all qualifying data
+    assert (not shuffle)
+    train_batch_size = 1
+    val_batch_size = 1  # Dataset[i] will return all qualifying data of idx i
+
+load_in_memory = False
+use_all = False
+use_pooled = False
+use_lstm = True
+
+# Training
+overwrite_expt_dir = True  # Set to True during dev phase
+max_epochs = 20
+test_interval = 8
+
+# Model
+attn_type = 'old'  # new or old
+num_attn_hops = 2
+soft_targets = False
+bidirectional = True
+lstm_out = 512
+emb_dim = 300
+cnn_feat_size = 2048  # 2048 for resnet/updn/clevr_layer4; 1024 for clevr layer_3
+
+classfier_dropout = True
+embedding_dropout = True
+attention_dropout = True
+num_hidden = 1024
+use_model = vqa_models.UpDown  # BLAH
+optimizer = torch.optim.Adamax
+lr = 2e-3
+save_models = False
+if not soft_targets:
+    train_on = 'valid'
+num_rehearsal_samples = 50
diff --git a/vqa_experiments/metric.py b/vqa_experiments/metric.py
index fd5d93b..35fb7bc 100644
--- a/vqa_experiments/metric.py
+++ b/vqa_experiments/metric.py
@@ -66,7 +66,41 @@ def compute_clevr_per_type_accuracies(path, preds):
     print(some_qids)
 
 
+def compute_tdiuc_accuracy(PATH, preds):
+    gt_answers = h5py.File(f'{PATH}/val_tdiuc.h5')['aidx'][:]
+    gt_qids = h5py.File(f'{PATH}/val_tdiuc.h5')['qid'][:]
+    gt_qtypes = h5py.File(f'{PATH}/val_tdiuc.h5')['qtypeidx'][:]
+
+    qid2qtype = {qid: gt for qid, gt in zip(gt_qids, gt_qtypes)}
+    qid2gt = {qid: gt for qid, gt in zip(gt_qids, gt_answers)}
+
+    acc = defaultdict(list)
+
+    for qid in qid2gt:
+        gt = qid2gt[qid]
+        qtype = qid2qtype[qid]
+        if gt == preds[str(qid)]:
+            acc['overall'].append(1)
+            acc[qtype].append(1)
+        else:
+            acc['overall'].append(0)
+            acc[qtype].append(0)
+
+    mpt = 0
+    overall = 0
+    for k in acc:
+        if k == 'overall':
+            overall = sum(acc[k]) / len(acc[k])
+        else:
+            mpt += sum(acc[k]) / len(acc[k])
+    mpt = mpt / 12  # TDIUC has 12 question types
+
+    return mpt, overall
+
+
 def compute_accuracy(path, dataset, preds):
     if dataset == 'clevr':
         mpt, overall = compute_clevr_accuracy(path, preds)
+    elif dataset == 'tdiuc':
+        mpt, overall = compute_tdiuc_accuracy(path, preds)
     print(f"Mean Per Type: {mpt}, Overall: {overall}")
diff --git a/vqa_experiments/run_clevr_experiment.sh b/vqa_experiments/run_clevr_experiment.sh
index 5673af2..1dde09b 100644
--- a/vqa_experiments/run_clevr_experiment.sh
+++ b/vqa_experiments/run_clevr_experiment.sh
@@ -10,7 +10,7 @@ export PYTHONPATH=/hdd/robik/projects/REMIND
 #--expt_name ${expt} \
 #--stream_with_rehearsal \
 #--lr ${lr} &> logs/${expt}.log &
-DATA_ORDER=iid # or qtype
+DATA_ORDER=qtype  # or iid
 expt=${CONFIG}_${DATA_ORDER}_${lr}
 
 CUDA_VISIBLE_DEVICES=0 python -u vqa_experiments/vqa_trainer.py \
diff --git a/vqa_experiments/run_tdiuc_experiment.sh b/vqa_experiments/run_tdiuc_experiment.sh
new file mode 100644
index 0000000..44ed6dd
--- /dev/null
+++ b/vqa_experiments/run_tdiuc_experiment.sh
@@ -0,0 +1,26 @@
+#!/usr/bin/env bash
+#source activate remind_proj
+
+lr=2e-3
+CONFIG=TDIUC_streaming
+export PYTHONPATH=/hdd/robik/projects/REMIND
+
+#DATA_ORDER=iid
+#expt=${CONFIG}_${DATA_ORDER}_${lr}
+
+#CUDA_VISIBLE_DEVICES=0 nohup python -u vqa_experiments/vqa_trainer.py \
+#--config_name ${CONFIG} \
+#--expt_name ${expt} \
+#--stream_with_rehearsal \
+#--data_order ${DATA_ORDER} \
+#--lr ${lr} &> logs/${expt}.log &
+
+DATA_ORDER=qtype  # or iid
+expt=${CONFIG}_${DATA_ORDER}_${lr}
+
+CUDA_VISIBLE_DEVICES=0 python -u vqa_experiments/vqa_trainer.py \
+--config_name ${CONFIG} \
+--expt_name ${expt} \
+--stream_with_rehearsal \
+--data_order ${DATA_ORDER} \
+--lr ${lr}
diff --git a/vqa_experiments/tdiuc/extract_image_features_tdiuc.py b/vqa_experiments/tdiuc/extract_image_features_tdiuc.py
new file mode 100644
index 0000000..d9a1f6e
--- /dev/null
+++ b/vqa_experiments/tdiuc/extract_image_features_tdiuc.py
@@ -0,0 +1,122 @@
+import argparse, os
+import h5py
+import numpy as np
+from scipy.misc import imread, imresize  # requires SciPy < 1.2; imread/imresize were removed in later releases
+
+import torch
+import torchvision
+import json
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--path', type=str, default='/hdd/robik/TDIUC')
+parser.add_argument('--max_images', default=None, type=int)
+
+parser.add_argument('--image_height', default=224, type=int)
+parser.add_argument('--image_width', default=224, type=int)
+
+parser.add_argument('--model', default='resnet152')
+parser.add_argument('--model_stage', default=4, type=int)
+parser.add_argument('--batch_size', default=128, type=int)
+
+
+def build_model(args):
+    if not hasattr(torchvision.models, args.model):
+        raise ValueError('Invalid model "%s"' % args.model)
+    if 'resnet' not in args.model:
+        raise ValueError('Feature extraction only supports ResNets')
+    cnn = getattr(torchvision.models, args.model)(pretrained=True)
+    layers = [
+        cnn.conv1,
+        cnn.bn1,
+        cnn.relu,
+        cnn.maxpool,
+    ]
+    for i in range(args.model_stage):
+        name = 'layer%d' % (i + 1)
+        layers.append(getattr(cnn, name))
+    model = torch.nn.Sequential(*layers)
+    model.cuda()
+    model.eval()
+    return model
+
+
+def run_batch(cur_batch, model):
+    # ImageNet channel mean/std
+    mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
+    std = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
+
+    image_batch = np.concatenate(cur_batch, 0).astype(np.float32)
+    image_batch = (image_batch / 255.0 - mean) / std
+    image_batch = torch.FloatTensor(image_batch).cuda()
+    with torch.no_grad():
+        feats = model(image_batch)
+        feats = feats.data.cpu().clone().numpy()
+    return feats
+
+
+def path2iid(path):
+    return int(path.split('/')[-1].split('.')[0].split('_')[-1])
+
+
+def main(args):
+    args.output_h5_file = args.path + "/all_tdiuc_resnet.h5"
+    p1 = f'{args.path}/Images/train2014'
+    input_paths = [os.path.join(p1, a) for a in os.listdir(p1)]
+
+    p1 = f'{args.path}/Images/val2014'
+    input_paths.extend([os.path.join(p1, a) for a in os.listdir(p1)])
+
+    model = build_model(args)
+    img_size = (args.image_height, args.image_width)
+    with h5py.File(args.output_h5_file, 'w') as f:
+        feat_dset = None
+        i0 = 0
+        cur_batch = []
+        iid = []
+        for i, path in enumerate(input_paths):
+            iid.append(path2iid(path))
+            img = imread(path, mode='RGB')
+            img = imresize(img, img_size, interp='bicubic')
+            img = img.transpose(2, 0, 1)[None]
+            cur_batch.append(img)
+            if len(cur_batch) == args.batch_size:
+                feats = run_batch(cur_batch, model)
+                if feat_dset is None:
+                    N = len(input_paths)
+                    _, C, H, W = feats.shape
+                    feat_dset = f.create_dataset('image_features', (N, H * W, C),
+                                                 dtype=np.float32)
+                    iid_dset = f.create_dataset('iids', (N,),
+                                                dtype=np.int64)
+
+                i1 = i0 + len(cur_batch)
+                feats_r = feats.reshape(-1, 2048, 49)
+                feat_dset[i0:i1] = np.transpose(feats_r, (0, 2, 1))
+                i0 = i1
+                print('Processed %d / %d images' % (i1, len(input_paths)))
+                cur_batch = []
+
+        if len(cur_batch) > 0:
+            feats = run_batch(cur_batch, model)
+            feats_r = feats.reshape(-1, 2048, 49)
+            i1 = i0 + len(cur_batch)
+            feat_dset[i0:i1] = np.transpose(feats_r, (0, 2, 1))
+            print('Processed %d / %d images' % (i1, len(input_paths)))
+        iid_dset[:len(iid)] = np.array(iid, dtype=np.int64)
+
+    feat_file = h5py.File(args.output_h5_file, 'r')
+
+    iid_list = feat_file['iids'][:]
+
+    iid2idx = {str(iid): idx for idx, iid in enumerate(iid_list)}
+    idx2iid = {idx: str(iid) for idx, iid in enumerate(iid_list)}
+
+    lut = dict()
+    lut['image_id_to_ix'] = iid2idx
+    lut['image_ix_to_id'] = idx2iid
+
+    json.dump(lut, open(f'{args.path}/map_tdiuc_resnet.json', 'w'))
+
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    main(args)
diff --git a/vqa_experiments/tdiuc/extract_question_features_tdiuc.py b/vqa_experiments/tdiuc/extract_question_features_tdiuc.py
new file mode 100644
index 0000000..08b76e4
--- /dev/null
+++ b/vqa_experiments/tdiuc/extract_question_features_tdiuc.py
@@ -0,0 +1,49 @@
+"""
+Written by Kushal, modified by Robik
+"""
+
+import sys
+import json
+import h5py
+import numpy as np
+
+DATA_PATH = '/hdd/robik/TDIUC'
+GENSEN_PATH = '/hdd/robik/projects/gensen'
+sys.path.append(GENSEN_PATH)
+
+from gensen import GenSen, GenSenSingle
+
+gensen_1 = GenSenSingle(
+    model_folder=f'{GENSEN_PATH}/data/models',
+    filename_prefix='nli_large_bothskip',
+    cuda=True,
+    pretrained_emb=f'{GENSEN_PATH}/data/embedding/glove.840B.300d.h5'
+)
+
+for split in ['train2014', 'val2014']:
+    feat_h5 = h5py.File(f'{DATA_PATH}/questions_{split}_tdiuc.h5', 'w')
+    ques = json.load(open(f'{DATA_PATH}/Questions/OpenEnded_mscoco_{split}_questions.json'))
+    ques = ques['questions']
+    questions = [q['question'] for q in ques]
+    qids = [q['question_id'] for q in ques]
+    qids = np.int64(qids)
+    dt = h5py.special_dtype(vlen=str)
+    feat_h5.create_dataset('feats', (len(qids), 2048), dtype=np.float32)
+    feat_h5.create_dataset('qids', (len(qids),), dtype=np.int64)
+    feat_h5.create_dataset('questions', (len(qids),), dtype=dt)
+    feat_h5['qids'][:] = qids
+    feat_h5['questions'][:] = questions
+
+    chunksize = 5000
+    question_chunks = [questions[x:x + chunksize] for x in range(0, len(questions), chunksize)]
+
+    done = 0
+    for qchunk in question_chunks:
+        print(done)
+        _, reps_h_t = gensen_1.get_representation(
+            qchunk, pool='last', return_numpy=True, tokenize=True
+        )
+        feat_h5['feats'][done:done + len(qchunk)] = reps_h_t
+        done += len(qchunk)
+
+    feat_h5.close()
diff --git a/vqa_experiments/tdiuc/pq_encoding_tdiuc.py b/vqa_experiments/tdiuc/pq_encoding_tdiuc.py
new file mode 100644
index 0000000..6bfd60c
--- /dev/null
+++ b/vqa_experiments/tdiuc/pq_encoding_tdiuc.py
@@ -0,0 +1,69 @@
+"""
+Written by Kushal, modified by Robik
+"""
+import faiss
+import numpy as np
+import h5py
+import json
+
+print("Starting the encoding process...")
+# Change these based on the data set
+PATH = '/hdd/robik/TDIUC'  # Change this
+streaming_type = 'qtype'  # Change this ('iid' or 'qtype')
+
+# These probably don't need to be changed
+feat_name = f'{PATH}/all_tdiuc_resnet'
+train_filename = f'{PATH}/train_tdiuc.h5'
+lut_name = f'{PATH}/map_tdiuc_resnet.json'
+
+feat_dim = 2048
+num_feat_maps = 49
+
+train_data = h5py.File(train_filename, 'r')
+lut = json.load(open(lut_name))
+feat_h5 = h5py.File(f'{feat_name}.h5', 'r')
+
+if streaming_type == 'iid':
+    ids = train_data['iid'][:]
+    feat_idxs = list(set([lut['image_id_to_ix'][str(id)] for id in ids]))
+    feat_idxs_base_init = feat_idxs[:int(0.1 * len(feat_idxs))]
+else:
+    feat_idxs_base_init = list(set([lut['image_id_to_ix'][str(iid)] for qtypeidx, iid in
                                    zip(train_data['qtypeidx'], train_data['iid']) if qtypeidx == 0]))
+
+print(f"# samples for base init {len(feat_idxs_base_init)}")
+train_data_base_init = np.array([feat_h5['image_features'][bidx] for bidx in feat_idxs_base_init], dtype=np.float32)
+
+# train set
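+# Each image contributes num_feat_maps (49) spatial vectors of dimension
+# feat_dim (2048); flattening them below forms the PQ training set. With the
+# 32-byte code size set further down, each 2048-d float32 vector (8192 bytes)
+# is quantized to 32 bytes, i.e. 256x compression.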
+train_data_base_init = np.reshape(train_data_base_init, (-1, feat_dim))
+
+print('Training Product Quantizer')
+
+d = feat_dim  # data dimension
+cs = 32  # code size (bytes)
+pq = faiss.ProductQuantizer(d, cs, 8)
+pq.train(train_data_base_init)
+
+print('Encoding, Decoding and saving Reconstructed Features')
+
+feats = feat_h5['image_features']
+start = 0
+batch = 10000
+reconstructed_h5 = h5py.File(f'{feat_name}pq_{streaming_type}.h5', 'w')
+reconstructed_h5.create_dataset('image_features', shape=feats.shape, dtype=np.float32)
+
+while start < len(feats):
+    print(start, ' feats done out of ', len(feats))
+    data_batch = feats[start:start + batch]
+    num_feats = len(data_batch)
+    data_batch = np.reshape(data_batch, (-1, feat_dim))
+    codes = pq.compute_codes(data_batch)
+    data_batch_reconstructed = pq.decode(codes)
+    data_batch_reconstructed = np.reshape(data_batch_reconstructed, (-1, num_feat_maps, feat_dim))
+    reconstructed_h5['image_features'][start:start + num_feats] = data_batch_reconstructed
+    start = start + batch
+
+reconstructed_h5.close()
+
+# Boundary points: [21602, 463412, 575269, 821512, 840988, 903850, 905311, 905661, 950335, 976377, 982225, 1115299]
diff --git a/vqa_experiments/tdiuc/preprocess_tdiuc.py b/vqa_experiments/tdiuc/preprocess_tdiuc.py
new file mode 100644
index 0000000..30d4cdb
--- /dev/null
+++ b/vqa_experiments/tdiuc/preprocess_tdiuc.py
@@ -0,0 +1,76 @@
+"""
+Written by Kushal, modified by Robik
+"""
+
+import json
+import h5py
+import numpy as np
+from collections import Counter, defaultdict
+from tqdm import tqdm
+
+PATH = '/hdd/robik/TDIUC'
+annotations = dict()
+for split in ['train', 'val']:
+    annotations[split] = json.load(
+        open(f'{PATH}/Annotations/mscoco_{split}2014_annotations.json'))['annotations']
+
+meta = defaultdict(list)
+
+for ann in annotations['train']:
+    # Replicate the single TDIUC answer to fill the VQA-style 10-answer layout
+    ten_ans = [a['answer'] for a in ann['answers']] * 10
+    ans = ten_ans[0]
+    meta['a'].append(ans)
+    meta['atype'].append('answer_type')  # single placeholder answer type
+    meta['qtype'].append(ann['question_type'])
+
+lut = dict()
+
+for m in ['a', 'atype', 'qtype']:
+    most_common = Counter(meta[m]).most_common()
+    lut[f'{m}2idx'] = {a[0]: idx for idx, a in enumerate(most_common)}
+
+json.dump(lut, open(f'{PATH}/LUT_tdiuc.json', 'w'))
+# %%
+dt = h5py.special_dtype(vlen=str)
+for split in ['train', 'val']:
+    qfeat_file = h5py.File(f'{PATH}/questions_{split}2014_tdiuc.h5', 'r')
+
+    mem_feat = dict()
+    for dset in qfeat_file.keys():
+        mem_feat[dset] = qfeat_file[dset][:]
+    qids = mem_feat['qids'][:]
+    qid2idx = {qid: idx for idx, qid in enumerate(qids)}
+    num_instances = len(annotations[split])
+    h5file = h5py.File(f'{PATH}/{split}_tdiuc.h5', 'w')
+    h5file.create_dataset('qfeat', (num_instances, 2048), dtype=np.float32)
+    h5file.create_dataset('qid', (num_instances,), dtype=np.int64)
+    h5file.create_dataset('iid', (num_instances,), dtype=np.int64)
+    h5file.create_dataset('q', (num_instances,), dtype=dt)
+    h5file.create_dataset('a', (num_instances,), dtype=dt)
+    h5file.create_dataset('ten_ans', (num_instances, 10), dtype=dt)
+    h5file.create_dataset('aidx', (num_instances,), dtype=np.int32)
+    h5file.create_dataset('ten_aidx', (num_instances, 10), dtype=np.int32)
+    h5file.create_dataset('atypeidx', (num_instances,), dtype=np.int32)
+    h5file.create_dataset('qtypeidx', (num_instances,), dtype=np.int32)
+
+    for idx, ann in enumerate(tqdm(annotations[split])):
+        qid = ann['question_id']
+        iid = ann['image_id']
+        feat_idx = qid2idx[qid]
+        ten_ans = [a['answer'] for a in ann['answers']] * 10
+        ans = ten_ans[0]
+        aidx = lut['a2idx'].get(ans, -1)
+        ten_aidx = np.array([lut['a2idx'].get(a, -1) for a in ten_ans])
+        atypeidx = lut['atype2idx'].get('answer_type', -1)
+        qtypeidx = lut['qtype2idx'].get(ann['question_type'], -1)
+        h5file['qfeat'][idx] = mem_feat['feats'][feat_idx]
+        h5file['qid'][idx] = qid
+        h5file['iid'][idx] = iid
+        h5file['q'][idx] = mem_feat['questions'][feat_idx]
+        h5file['a'][idx] = ans
+        h5file['ten_ans'][idx] = ten_ans
+        h5file['aidx'][idx] = aidx
+        h5file['atypeidx'][idx] = atypeidx
+        h5file['qtypeidx'][idx] = qtypeidx
+        h5file['ten_aidx'][idx] = ten_aidx
+    h5file.close()
diff --git a/vqa_experiments/vqa_trainer.py b/vqa_experiments/vqa_trainer.py
index ef67cdf..38ebde9 100644
--- a/vqa_experiments/vqa_trainer.py
+++ b/vqa_experiments/vqa_trainer.py
@@ -379,12 +379,18 @@ def exponential_averaging(model1, model2, decay=0.999):
 
 # %%
 def main():
-    config.image_stage = f'largestage3pq_{args.data_order}'  # stage3 or stage4 if clevr is used
-
-    config.feat_path = f'{config.data_path}/all_clevr_resnet_{config.image_stage}.h5'
+    if config.dataset == 'clevr':
+        config.feat_path = f'{config.data_path}/all_clevr_resnet_largestage3pq_{args.data_order}.h5'
+    else:
+        config.feat_path = f'{config.data_path}/all_tdiuc_resnetpq_{args.data_order}.h5'
     config.expt_dir = 'snapshots/' + args.expt_name
     config.use_exponential_averaging = args.use_exponential_averaging
 
+    config.data_order = args.data_order
+    if config.data_order == 'iid':
+        config.arrangement = {'train': 'random', 'val': 'random'}
+    else:
+        config.arrangement = {'train': 'qtypeidx', 'val': 'qtypeidx'}
     if not config.overwrite_expt_dir:
         assert_expt_name_not_present(
             config.expt_dir)  # Just comment it out during dev phase, otherwise it can get annoying
@@ -458,6 +464,7 @@ def main():
     if config.use_exponential_averaging:
         exponential_averaging(net_running, net, 0)
 
+    print(json.dumps(args.__dict__, indent=4, sort_keys=True))
     shutil.copy('vqa_experiments/configs/config_' + args.config_name + '.py',
                 os.path.join(config.expt_dir, 'config_' + args.config_name + '.py'))
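For reference, the `compute_tdiuc_accuracy` function added to `metric.py` reports both overall accuracy and mean per-type (MPT) accuracy averaged over TDIUC's 12 question types; the two diverge when question types are imbalanced. A toy illustration with two hypothetical types (made-up numbers, not results from the paper):

```python
from collections import defaultdict

acc = defaultdict(list)
acc['overall'] = [1] * 90 + [0] * 10 + [0] * 9 + [1]  # 91 of 110 correct
acc[0] = [1] * 90 + [0] * 10  # frequent question type: 90% correct
acc[1] = [0] * 9 + [1]        # rare question type: 10% correct

overall = sum(acc['overall']) / len(acc['overall'])
# metric.py divides by 12 for TDIUC; this toy example has only 2 types.
mpt = sum(sum(v) / len(v) for k, v in acc.items() if k != 'overall') / 2
print(f'overall={overall:.2f}, mean-per-type={mpt:.2f}')  # overall=0.83, mpt=0.50
```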