-
Notifications
You must be signed in to change notification settings - Fork 4
/
test.py
127 lines (110 loc) · 5.43 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
from multiprocessing import Pool
import time
import argparse
import torch
import torch.utils.tensorboard
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from datasets import *
from utils.misc import *
from utils.transforms import *
from models.denoise import *
from models.utils import chamfer_distance_unit_sphere
from Evaluate import Evaluator
def input_iter(input_dir):
    """Yield the normalized noisy point clouds stored as ``.xyz`` files.

    Files are visited in sorted file-name order. Each one is parsed with
    ``np.loadtxt``, converted to a ``torch.FloatTensor`` and passed through
    ``NormalizeUnitSphere.normalize``.

    Args:
        input_dir: Directory to scan for ``.xyz`` files.

    Yields:
        dict with the keys
            ``'pcl_noisy'``: the normalized point cloud tensor,
            ``'name'``: the file name without its ``.xyz`` extension,
            ``'center'`` / ``'scale'``: the values returned by the
            normalization, which the caller needs to undo it.
    """
    for fn in sorted(os.listdir(input_dir)):
        # Split on the real extension: a name that merely ends in "xyz"
        # (no dot) is not a point-cloud file and must be skipped rather than
        # loaded under a truncated name.
        name, ext = os.path.splitext(fn)
        if ext != '.xyz':
            continue
        pcl_noisy = torch.FloatTensor(np.loadtxt(os.path.join(input_dir, fn)))
        pcl_noisy, center, scale = NormalizeUnitSphere.normalize(pcl_noisy)
        yield {
            'pcl_noisy': pcl_noisy,
            'name': name,
            'center': center,
            'scale': scale,
        }
def main(noise):
    """Denoise all point clouds of one noise level and evaluate the results.

    Configuration is read from the module-level ``args`` namespace. For every
    entry of ``args.resolutions`` the function

    1. reads the ``.xyz`` clouds from
       ``<input_root>/<dataset>_<resolution>_<noise>``,
    2. runs ``args.niters`` denoising passes on each of them,
    3. undoes the normalization and writes the result to
       ``<output_root>/<save_title>/<name>.xyz``,
    4. runs ``Evaluator`` on that output directory, unless the dataset name
       starts with ``'RueMadame'``.

    Args:
        noise: Noise-level tag (a string) used in the input/output directory
            names and in the logger name.
    """
    # The checkpoint does not depend on the resolution, so load it and switch
    # it to evaluation mode once instead of once per resolution / per file.
    model = DenoiseNet.load_from_checkpoint(args.ckpt)
    model = model.to(args.device)
    model.eval()

    # Both denoising entry points are called with the same keyword arguments;
    # pick the one to use up front.
    if args.patch_stitching:
        denoise = model.patch_based_denoise
    else:
        denoise = model.patch_based_denoise_without_stitching

    for resolution in args.resolutions:
        # Input/Output
        input_dir = os.path.join(
            args.input_root, '%s_%s_%s' % (args.dataset, resolution, noise))
        save_title = '{dataset}_Ours{modeltag}_{tag}_{res}_{noise}'.format(
            dataset=args.dataset,
            modeltag='' if args.niters == 1 else '%dx' % args.niters,
            tag=args.tag,
            res=resolution,
            noise=noise,
        )
        output_dir = os.path.join(args.output_root, save_title)
        # exist_ok avoids the check-then-create race of a separate exists().
        os.makedirs(output_dir, exist_ok=True)  # Output point clouds

        logger = get_logger(
            'test_%s_%s_%s' % (args.dataset, resolution, noise), output_dir)
        for k, v in vars(args).items():
            logger.info('[ARGS::%s] %s' % (k, repr(v)))

        # Denoise
        for data in input_iter(input_dir):
            logger.info(data['name'])
            pcl_next = data['pcl_noisy'].to(args.device)
            with torch.no_grad():
                # Each pass takes the previous pass's output as its input.
                for _ in range(args.niters):
                    pcl_next = denoise(
                        pcl_noisy=pcl_next,
                        patch_size=args.patch_size,
                        seed_k=args.seed_k,
                        seed_k_alpha=args.seed_k_alpha,
                        num_modules_to_use=args.num_modules_to_use,
                    )
            pcl_denoised = pcl_next.cpu()

            # Denormalize
            pcl_denoised = pcl_denoised * data['scale'] + data['center']
            save_path = os.path.join(output_dir, data['name'] + '.xyz')
            np.savetxt(save_path, pcl_denoised.numpy(), fmt='%.8f')

        if not args.dataset.startswith('RueMadame'):
            # Evaluate
            evaluator = Evaluator(
                output_pcl_dir=output_dir,
                dataset_root=args.dataset_root,
                dataset='PUNet' if args.dataset.startswith('PUNet') else args.dataset,
                summary_dir=args.output_root,
                experiment_name=save_title,
                device=args.device,
                res_gts=resolution,
                logger=logger,
            )
            evaluator.run()
def str_bool(value):
    """Convert a command-line string such as ``true``/``false`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text -- including ``'False'`` -- becomes ``True``. This
    converter interprets the text instead.

    Args:
        value: The raw option text (a ``bool`` is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def _init_worker(parsed_args):
    """Make the parsed arguments available as the global ``args`` in a worker.

    ``main`` reads the module-level ``args``. With the ``spawn`` and
    ``forkserver`` start methods a worker re-imports this module without
    executing the ``__main__`` guard, so ``args`` would be undefined there and
    the random seed would not be applied. With ``fork`` this only repeats what
    the worker already inherited from the parent.
    """
    global args
    args = parsed_args
    seed_all(parsed_args.seed)


if __name__ == '__main__':
    # Arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', type=str, default='./pretrained/denoisenet-ep-99.ckpt')
    parser.add_argument('--num_modules_to_use', type=int, default=None)
    parser.add_argument('--input_root', type=str, default='./data/examples')
    parser.add_argument('--output_root', type=str, default='./data/results')
    parser.add_argument('--dataset_root', type=str, default='./data')
    parser.add_argument('--dataset', type=str, default='PUNet')
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--resolutions', type=str_list, default=['50000_poisson'])  # Set your test resolution
    parser.add_argument('--noise_lvls', type=str_list, default=['0.01'])  # Set your test noise level
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--seed', type=int, default=2020)
    # Filtering parameters
    parser.add_argument('--patch_size', type=int, default=1000)
    parser.add_argument('--niters', type=int, default=1)
    parser.add_argument('--denoise_knn', type=int, default=None, help='kNN size to use during testing')
    # Patch stitching params
    # type=str_bool (not type=bool) so that "--patch_stitching False" really disables it.
    parser.add_argument('--patch_stitching', type=str_bool, default=True, help='Use patch stitching or not?')
    parser.add_argument('--seed_k', type=int, default=6)  # 6 for Kinect, 6 for small PCL, 6 for RueMadame PCL
    parser.add_argument('--seed_k_alpha', type=int, default=10)  # 2 for Kinect, 10 for small PCL, 20 for RueMadame PCL
    args = parser.parse_args()
    seed_all(args.seed)

    # One worker process per noise level; the initializer hands every worker
    # the parsed arguments regardless of the multiprocessing start method.
    with Pool(len(args.noise_lvls), initializer=_init_worker, initargs=(args,)) as p:
        p.map(main, args.noise_lvls)