-
Notifications
You must be signed in to change notification settings - Fork 3
/
oxford_ours.py
119 lines (112 loc) · 2.61 KB
/
oxford_ours.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Parent config files, given relative to this file. Presumably merged by the
# project's config loader before the values below are applied — confirm
# against the loader.
_base_ = ['./base_cfg.py', './dataset_cfgs/oxford_cfg.py']
# Task identifier string. NOTE(review): presumably selects the training
# pipeline; confirm against the code that consumes `task_type`.
task_type = 'ours_me'

# Optimizer selection: a type name plus the keyword settings kept with it.
# NOTE(review): the name and keys match torch.optim.Adam's arguments, but the
# lookup itself is not visible in this file — confirm.
optimizer_type = 'Adam'
optimizer_cfg = {
    'lr': 2e-4,
    'weight_decay': 0,
    'betas': (0.9, 0.999),
}
# Learning-rate schedule: a type name plus its keyword settings.
# NOTE(review): if this resolves to torch.optim.lr_scheduler.MultiStepLR, the
# learning rate is multiplied by `gamma` at each milestone — confirm.
scheduler_type = 'MultiStepLR'
scheduler_cfg = {
    'gamma': 0.1,
    'milestones': (80, 120, 160),
}

# Final epoch value; every milestone above is smaller than it.
end_epoch = 200
# Training-loop settings.
train_cfg = {
    # Intervals, counted in epochs according to the key names.
    'save_per_epoch': 10,
    'val_per_epoch': 5,
    # Batch sampler selected by name, with its keyword settings.
    # NOTE(review): presumably the batch size starts at `batch_size` and is
    # scaled by `batch_size_expansion_rate` up to `max_batch_size` once some
    # statistic crosses `batch_expansion_threshold` — the sampler is not
    # visible here, confirm.
    'batch_sampler_type': 'ExpansionBatchSampler',
    'batch_sampler_cfg': {
        'max_batch_size': 128,
        'batch_size_expansion_rate': 1.4,
        'batch_expansion_threshold': 0.7,
        'batch_size': 32,
        'shuffle': True,
        'drop_last': True,
    },
    # NOTE(review): 0 presumably means data is loaded in the main process, as
    # with torch's DataLoader — confirm.
    'num_workers': 0,
}
# Evaluation settings. Unlike the training sampler settings, `drop_last` is
# False here and no sampler type or expansion options are given.
eval_cfg = {
    'batch_sampler_cfg': {
        'batch_size': 32,
        'drop_last': False,
    },
    'num_workers': 0,
    # NOTE(review): presumably controls whether embeddings are normalized
    # before retrieval evaluation — confirm against the evaluator.
    'normalize_embeddings': False,
}
# Model selected by name. NOTE(review): presumably looked up in a project
# registry — confirm.
model_type = 'Ours'

# Model settings, grouped into backbone, pooling head and quantization sizes.
model_cfg = {
    'backbone_cfg': {
        # Three branches with identical settings; each branch lists two layer
        # configs (1 -> 64 channels, kernel 5, stride 1; then 64 -> 64
        # channels, kernel 3, stride 2). They are written out separately so
        # every branch owns its own dict objects.
        'up_conv_cfgs': [
            [
                {'in_channels': 1, 'out_channels': 64, 'kernel_size': 5, 'stride': 1},
                {'in_channels': 64, 'out_channels': 64, 'kernel_size': 3, 'stride': 2},
            ],
            [
                {'in_channels': 1, 'out_channels': 64, 'kernel_size': 5, 'stride': 1},
                {'in_channels': 64, 'out_channels': 64, 'kernel_size': 3, 'stride': 2},
            ],
            [
                {'in_channels': 1, 'out_channels': 64, 'kernel_size': 5, 'stride': 1},
                {'in_channels': 64, 'out_channels': 64, 'kernel_size': 3, 'stride': 2},
            ],
        ],
        # `num_centers` has one entry per attention layer (6 entries for
        # num_attn_layers=6).
        'transformer_cfg': {
            'num_attn_layers': 6,
            'global_channels': 64,
            'local_channels': 0,
            'num_centers': [256, 128, 128, 64, 32, 32],
            'num_heads': 4,
            'time_dim': 8,
            'learned_sinusoidal_cond': True,
        },
        'pointnet_cfg': {
            'std_cfg': {
                'conv_channels': [64, 128, 512],
                'fc_channels': [256, 128],
            },
            'global_feat': True,
            'channels': [64, 128, 256],
        },
        'in_channels': 1,
        'out_channels': 512,
        'fine_to_coarse': False,
        'step_size': 2,
    },
    # Pooling head; its in_channels equals the backbone's out_channels (512).
    'pool_cfg': {
        'type': 'NetVlad',
        'in_channels': 512,
        'out_channels': 512,
        'cluster_size': 64,
        'gating': True,
        'add_bn': True,
    },
    # Three sizes, the same count as the branches in `up_conv_cfgs`.
    # NOTE(review): presumably one voxel size per branch — confirm.
    'quantization_size': [0.01, 0.12, 0.2],
}