-
Notifications
You must be signed in to change notification settings - Fork 4
/
kfold_train.py
83 lines (61 loc) · 2.94 KB
/
kfold_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import datetime
from random import random
import torch
import os
import numpy as np
from sklearn.model_selection import KFold
from detectron2.data.catalog import DatasetCatalog
from detectron2.data.datasets.coco import load_coco_json
from detectron2.utils import logger
from maskformer2 import init_mask2former, train_mask2former
from standard_maskrcnn import init_maskrcnn, train_maskrcnn
from rotated_maskrcnn import init_rotated_maskrcnn, train_rotated_maskrcnn
import warnings
# Hide every UserWarning for the rest of the process (this filter is global).
warnings.filterwarnings("ignore", category=UserWarning)

# ----------------------------------------------------------------------------
# Program variables — edit these to configure a run.
# ----------------------------------------------------------------------------

# Architecture to cross-validate: "maskrcnn", "maskrcnn-rotated" or
# "mask2former". It is used as a key into DATASET_FILENAMES and CONFIG_FILES.
MODEL_NAME = "maskrcnn"

# Number of folds for K-fold cross-validation.
N_FOLDS = 5

# COCO annotation file for each model; the rotated Mask R-CNN reads its own file.
DATASET_FILENAMES = {
    "mask2former": 'data/prescaled/coco_annotation.json',
    "maskrcnn": 'data/prescaled/coco_annotation.json',
    "maskrcnn-rotated": 'data/prescaled/coco_annotation_rotated.json',
}

# Model config (YAML) for each model.
CONFIG_FILES = {
    "mask2former": "configs/config_mask2former_swinB.yaml",
    "maskrcnn": "configs/config_standard_maskrcnn.yaml",
    "maskrcnn-rotated": "configs/config_rotated_maskrcnn.yaml",
}

# Folder where the data is located; passed to load_coco_json as the image root.
IMAGE_DIR = 'data/prescaled'

# Path to a previous checkpoint to fine-tune from, or None.
# NOTE(review): not read anywhere in this file — confirm whether it is still used.
INITIAL_WEIGHTS = None
# Model names that kfold_train() knows how to dispatch.
_SUPPORTED_MODELS = ("mask2former", "maskrcnn", "maskrcnn-rotated")


def _unsupported_model_message(model_name):
    """Return the error message used for an unknown model name."""
    return (
        f"Unknown model_name {model_name!r}; expected one of: "
        f"{', '.join(_SUPPORTED_MODELS)}"
    )


def _init_model(model_name, config_file, train_dicts, test_dicts, output_dir, fold_num):
    """Build the config for one fold and return it together with its trainer.

    Returns:
        ``(cfg, train_fn)``, where ``train_fn(cfg)`` trains *model_name*.

    Raises:
        ValueError: if *model_name* is not one of ``_SUPPORTED_MODELS``.
    """
    if model_name == "mask2former":
        init_fn, train_fn = init_mask2former, train_mask2former
    elif model_name == "maskrcnn":
        init_fn, train_fn = init_maskrcnn, train_maskrcnn
    elif model_name == "maskrcnn-rotated":
        init_fn, train_fn = init_rotated_maskrcnn, train_rotated_maskrcnn
    else:
        raise ValueError(_unsupported_model_message(model_name))
    cfg = init_fn(config_file, train_dicts, test_dicts, output_dir, fold_num)
    return cfg, train_fn


def kfold_train(n_folds, model_name, dataset_dicts, config_file, seed=42):
    """Train *model_name* once per fold of a shuffled K-fold split.

    Each fold writes to ``./outputs/<YYYY-mm-dd_HH-MM>/fold_<i>``. The
    timestamp is taken once per call, so all folds of a run share one parent
    directory.

    Args:
        n_folds: number of folds (``n_splits`` for sklearn's ``KFold``).
        model_name: one of "mask2former", "maskrcnn", "maskrcnn-rotated".
        dataset_dicts: sequence of dataset records (e.g. the output of
            ``load_coco_json``). It is converted to a 1-D object array so it
            can be indexed with the fold index arrays.
        config_file: path to the config handed to the model's init function.
        seed: shuffle seed for ``KFold``. The default of 42 reproduces the
            split of the previously hard-coded ``RandomState(42)``.

    Raises:
        ValueError: if *model_name* is unsupported. This is checked before
            any split or training work starts.
    """
    # Fail fast. Previously an unknown name left `cfg` unbound and crashed
    # only inside the first fold.
    if model_name not in _SUPPORTED_MODELS:
        raise ValueError(_unsupported_model_message(model_name))

    # Fancy indexing with the KFold index arrays needs an ndarray. This is a
    # no-op when the caller already passes an object array.
    records = np.asarray(dataset_dicts, dtype=object)

    base_output_dir = f'./outputs/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")}'
    # An int seed (rather than a RandomState instance) keeps kf.split
    # reproducible even if it is called more than once.
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=seed)

    for fold_num, (train_index, test_index) in enumerate(kf.split(records)):
        output_dir = os.path.join(base_output_dir, f"fold_{fold_num}")
        cfg, train_fn = _init_model(
            model_name,
            config_file,
            records[train_index],
            records[test_index],
            output_dir,
            fold_num,
        )
        try:
            train_fn(cfg)
        finally:
            # Unregister this fold's train/test datasets (presumably registered
            # by the init function — TODO confirm) even if training fails, so
            # later registrations are not blocked by stale entries.
            DatasetCatalog.remove(cfg.DATASETS.TRAIN[0])
            DatasetCatalog.remove(cfg.DATASETS.TEST[0])
if __name__ == "__main__":
    # Report whether CUDA is usable and which torch build is installed.
    cuda_ok = torch.cuda.is_available()
    print('GPU available :', cuda_ok)
    torch_version = torch.__version__
    print('Torch version :', torch_version, '\n')
    logger.setup_logger(name=__name__)

    # Load the annotations for the selected model and wrap the returned list
    # in a numpy array so kfold_train can index it with fold index arrays.
    coco_records = load_coco_json(DATASET_FILENAMES[MODEL_NAME], IMAGE_DIR)
    dataset_dicts = np.array(coco_records)

    chosen_config = CONFIG_FILES[MODEL_NAME]
    kfold_train(N_FOLDS, MODEL_NAME, dataset_dicts, chosen_config)