-
Notifications
You must be signed in to change notification settings - Fork 8
/
test_scene.py
126 lines (103 loc) · 4.8 KB
/
test_scene.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
""" Derived from [NeuralRecon](https://github.com/zju3dv/NeuralRecon) by Jiaming Sun and Yiming Xie. """
import argparse
import os, time
import torch
from torch.utils.data import DataLoader
from loguru import logger
from tqdm import tqdm
from models import VisFusion
from utils import SaveScene
from config import cfg, update_config
from datasets import find_dataset_def, transforms
from tools.process_arkit_data import process_data
# ---------------------------------------------------------------------------
# Command-line interface: a required --cfg experiment file plus trailing
# free-form "opts" tokens that override individual config entries
# (both are consumed by update_config below).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Running VisFusion')
parser.add_argument('--cfg', type=str, required=True,
                    help='experiment configure file name')
parser.add_argument('opts', nargs=argparse.REMAINDER, default=None,
                    help="Modify config options using the command-line")

# parse arguments and merge them into the global config object
args = parser.parse_args()
update_config(cfg, args)
# Raw ARKit captures need a one-time pose/image synchronisation pass; the
# presence of SyncedPoses.txt marks a capture that was already processed.
if cfg.DATASET == 'ARKit':
    synced_poses_path = os.path.join(cfg.TEST.PATH, 'SyncedPoses.txt')
    if os.path.exists(synced_poses_path):
        logger.info("Found SyncedPoses.txt, skipping data pre-processing...")
    else:
        logger.info("First run on this captured data, start the pre-processing...")
        process_data(cfg.TEST.PATH)

logger.info("Running VisFusion...")
# Deterministic test-time data pipeline: resize to VGA, convert to tensors,
# place the voxel volume without any random augmentation, and fold camera
# intrinsics/poses into projection matrices for N_VIEWS keyframes.
transform = [transforms.ResizeImage((640, 480)),
             transforms.ToTensor(),
             transforms.RandomTransformSpace(
                 cfg.MODEL.N_VOX, cfg.MODEL.VOXEL_SIZE, random_rotation=False, random_translation=False,
                 paddingXY=0, paddingZ=0, max_epoch=cfg.TRAIN.EPOCHS),
             transforms.IntrinsicsPoseToProjection(cfg.TEST.N_VIEWS, 4)]
# FIX: the original rebound the imported `transforms` module to the composed
# transform object, shadowing the module name; use a distinct name instead.
test_transform = transforms.Compose(transform)

Dataset = find_dataset_def(cfg.DATASET)
test_dataset = Dataset(cfg.TEST.PATH, "test", test_transform, cfg.TEST.N_VIEWS, len(cfg.MODEL.THRESHOLDS) - 1, cfg.SCENE,
                       load_gt=False)
# batch size 1: fragments must be fed sequentially for incremental fusion
data_loader = DataLoader(test_dataset, 1, shuffle=False, num_workers=cfg.TEST.N_WORKERS, drop_last=False)
# model
logger.info("Initializing the model on GPU...")
model = VisFusion(cfg).cuda().eval()
model = torch.nn.DataParallel(model, device_ids=[0])

# Pick the checkpoint to resume from: either the file named explicitly in
# cfg.LOADCKPT, or the checkpoint in LOGDIR with the highest epoch index
# (file names are expected to end in "_<epoch>.ckpt").
if cfg.LOADCKPT != '':
    ckpt_name = cfg.LOADCKPT
else:
    ckpt_files = [fn for fn in os.listdir(cfg.LOGDIR) if fn.endswith(".ckpt")]
    ckpt_name = max(ckpt_files, key=lambda fn: int(fn.split('_')[-1].split('.')[0]))
loadckpt = os.path.join(cfg.LOGDIR, ckpt_name)

logger.info("Resuming from " + str(loadckpt))
state_dict = torch.load(loadckpt)
model.load_state_dict(state_dict['model'], strict=False)
epoch_idx = state_dict['epoch']
# Helper object that writes meshes / incremental visualisations to disk.
save_mesh_scene = SaveScene(cfg)
logger.info("Start inference..")
duration = 0.  # accumulated forward-pass wall time in seconds
gpu_mem_usage = []  # reserved GPU memory sampled after every fragment
frag_len = len(data_loader)  # total number of fragments to process
# ---------------------------------------------------------------------------
# Main inference loop: run the network fragment by fragment, optionally
# visualise / save incremental reconstructions, and save the final scene
# mesh after the last fragment has been fused.
# ---------------------------------------------------------------------------
with torch.no_grad():
    for frag_idx, sample in enumerate(tqdm(data_loader)):
        # save mesh if: 1. SAVE_SCENE_MESH and is the last fragment, or
        #               2. SAVE_INCREMENTAL, or
        #               3. VIS_INCREMENTAL
        save_scene = (cfg.SAVE_SCENE_MESH and frag_idx == frag_len - 1) or cfg.SAVE_INCREMENTAL or cfg.VIS_INCREMENTAL

        # time only the forward pass, not the saving/visualisation below
        start_time = time.time()
        outputs, loss_dict = model(sample, save_scene)
        duration += time.time() - start_time
        if cfg.REDUCE_GPU_MEM:
            # will slow down the inference
            torch.cuda.empty_cache()

        # vis or save incremental result.
        scene = sample['scene'][0]
        save_mesh_scene.keyframe_id = frag_idx
        # NOTE(review): pre_scene holds the previous '/'-replaced name while
        # `scene` is raw, so for scene ids containing '/' the comparison
        # below always differs — confirm intended for multi-scene runs.
        pre_scene = save_mesh_scene.scene_name
        save_mesh_scene.scene_name = scene.replace('/', '-')

        if cfg.SAVE_INCREMENTAL:
            # flush the accumulated frames to video on a scene change or
            # after the very last fragment
            save_video = ((pre_scene is not None) and (pre_scene != scene)) or (frag_idx == frag_len - 1)
            save_mesh_scene.save_incremental(epoch_idx, 0, sample, outputs, save_video=save_video)

        if cfg.VIS_INCREMENTAL:
            save_mesh_scene.vis_incremental(epoch_idx, 0, sample['imgs'][0], outputs)

        if cfg.SAVE_SCENE_MESH and frag_idx == frag_len - 1:
            # FIX: corrected typos in the user-facing message
            # ("a issue" -> "an issue", "attatched" -> "attached")
            assert 'scene_tsdf' in outputs, \
                """Reconstruction failed. Potential reasons could be:
                1. Wrong camera poses.
                2. Extremely difficult scene.
                If you can run with the demo data without any problem, please submit an issue with the failed data attached, thanks!
                """
            save_mesh_scene.save_scene_eval(epoch_idx, outputs)

        # sample reserved (not just allocated) memory for the summary stats
        gpu_mem_usage.append(torch.cuda.memory_reserved())
# Print a short run summary.  Guard against an empty data loader, which
# would otherwise raise ZeroDivisionError (duration / 0) and ValueError
# (max() of an empty sequence).
if frag_len > 0 and gpu_mem_usage:
    summary_text = f"""
Summary:
    Total number of fragments: {frag_len}
    Average keyframes/sec: {1 / (duration / (frag_len * cfg.TEST.N_VIEWS))}
    Average GPU memory usage (reserved) (GB): {sum(gpu_mem_usage) / len(gpu_mem_usage) / (1024 ** 3)}
    Max GPU memory usage (reserved) (GB): {max(gpu_mem_usage) / (1024 ** 3)}
"""
    print(summary_text)
else:
    logger.info("No fragments were processed; nothing to summarize.")

# release the visualisation resources opened by vis_incremental
if cfg.VIS_INCREMENTAL:
    save_mesh_scene.close()