-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_planercnn.py
executable file
·68 lines (53 loc) · 2.61 KB
/
train_planercnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import cv2
cv2.setNumThreads(0)
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
# from pytorch_lightning.profiler import PyTorchProfiler
from models.model import *
from datasets.scenenet_rgbd_stereo_dataset import *
from options import parse_args
from config import PlaneConfig
# Pretrained checkpoints loaded into MaskRCNN before fine-tuning.
# NOTE(review): these absolute paths are machine-specific; pass
# ``pretrained_weights=`` to train() to override them.
DEFAULT_PRETRAINED_WEIGHTS = (
    '/mnt/data/datasets/JW/scenenet_rgbd/checkpoint/mask_rcnn_coco.pth',
    '/mnt/data/datasets/JW/scenenet_rgbd/checkpoint/mvdnet_scannet.pth',
)


def _dataset_kwargs(options):
    """Return the extra keyword arguments to pass to ``ScenenetRgbdDataset``.

    When ``options.no_normals`` is truthy, normal loading is disabled
    explicitly. Otherwise no override is passed, so the dataset's own
    default for ``load_normals`` applies (as in the original branches).
    """
    return {'load_normals': False} if options.no_normals else {}


def _make_dataset(options, config, split):
    """Build a ``ScenenetRgbdDataset`` for *split* with ``random=False``."""
    return ScenenetRgbdDataset(options, config, split=split, random=False,
                               **_dataset_kwargs(options))


def _trainer_kwargs(options):
    """Return keyword arguments for ``pl.Trainer``.

    Uses a single GPU and ``options.numEpochs`` epochs. Training resumes from
    ``options.checkpoint`` only when it is non-empty. Both ``''`` and ``None``
    mean "start fresh".
    """
    kwargs = {'gpus': 1, 'max_epochs': options.numEpochs}
    if options.checkpoint:
        kwargs['resume_from_checkpoint'] = options.checkpoint
    return kwargs


def train(options, pretrained_weights=None):
    """Fine-tune MaskRCNN on the SceneNet RGB-D train split.

    The test split is used as the validation loader.

    Args:
        options: Parsed command-line options (from ``parse_args``). The fields
            read here are ``no_normals``, ``numEpochs`` and ``checkpoint``.
            ``num_plane_ids`` is written onto it as a side effect.
        pretrained_weights: Optional iterable of checkpoint paths handed to
            ``MaskRCNN.load_weights``. Defaults to
            ``DEFAULT_PRETRAINED_WEIGHTS``.
    """
    config = PlaneConfig(options)

    dataset = _make_dataset(options, config, 'train')
    train_loader = DataLoader(dataset, batch_size=1, shuffle=True,
                              num_workers=4)
    dataset_test = _make_dataset(options, config, 'test')
    test_loader = DataLoader(dataset_test, batch_size=1, shuffle=False,
                             num_workers=4)

    num_plane_ids = dataset.get_num_plane_ids()
    print('num_plane_ids: %d' % num_plane_ids)
    # Set before the model is built. Presumably MaskRCNN(config, options)
    # reads it to size its plane-id head. TODO confirm against models.model.
    options.num_plane_ids = num_plane_ids

    # Profiling hooks (disabled):
    # profiler = PyTorchProfiler(use_cuda=True, profile_memory=True, record_shapes=True)
    # profiler = SimpleProfiler()
    model = MaskRCNN(config, options)
    if pretrained_weights is None:
        pretrained_weights = DEFAULT_PRETRAINED_WEIGHTS
    model.load_weights(list(pretrained_weights))
    # Alternative: restore a full Lightning checkpoint instead of raw weights:
    # model = MaskRCNN.load_from_checkpoint(
    #     '/mnt/data/datasets/JW/scenenet_rgbd/checkpoint/mask_rcnn_mvdnet.ckpt',
    #     config=config, options=options, detect=True, strict=False)

    trainer = pl.Trainer(**_trainer_kwargs(options))
    trainer.fit(model, train_loader, test_loader)
    # print(profiler.key_averages().table(sort_by="self_gpu_memory_usage", row_limit=10))
if __name__ == '__main__':
    # Script entry point: parse CLI options, announce the task, then train.
    cli_options = parse_args()
    task_name = cli_options.task
    print('task=%s started' % task_name)
    train(cli_options)