-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathrein_dinov2_mask2former_1024x1024_bs4x2.py
83 lines (82 loc) · 2.44 KB
/
rein_dinov2_mask2former_1024x1024_bs4x2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# dataset config
_base_ = [
"../_base_/datasets/dg_citys2acdc_1024x1024.py",
"../_base_/default_runtime.py",
"../_base_/models/rein_dinov2_mask2former.py",
]
crop_size = (1024, 1024)
model = dict(
backbone=dict(
img_size=1024,
init_cfg=dict(
checkpoint="checkpoints/dinov2_converted_1024x1024.pth",
),
),
data_preprocessor=dict(
size=crop_size,
),
test_cfg=dict(
crop_size=(1024, 1024),
stride=(683, 683),
),
)
train_pipeline = [
dict(type="LoadImageFromFile"),
dict(type="LoadAnnotations"),
dict(
type="RandomChoiceResize",
scales=[int(1024 * x * 0.1) for x in range(5, 21)],
resize_type="ResizeShortestEdge",
max_size=4096,
),
dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75),
dict(type="RandomFlip", prob=0.5),
dict(type="PhotoMetricDistortion"),
dict(type="PackSegInputs"),
]
train_dataloader = dict(batch_size=2, dataset=dict(pipeline=train_pipeline))
# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
optim_wrapper = dict(
constructor="PEFTOptimWrapperConstructor",
optimizer=dict(
type="AdamW", lr=0.00006, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999)
),
paramwise_cfg=dict(
custom_keys={
"norm": dict(decay_mult=0.0),
"query_embed": embed_multi,
"level_embed": embed_multi,
"learnable_tokens": embed_multi,
"reins.scale": embed_multi,
},
norm_decay_mult=0.0,
),
)
param_scheduler = [
dict(type="LinearLR", start_factor=1e-6, by_epoch=False, begin=0, end=10000),
dict(
type="PolyLR",
eta_min=0.0,
power=0.9,
begin=10000,
end=40000,
by_epoch=False,
),
]
# training schedule for 160k
train_cfg = dict(type="IterBasedTrainLoop", max_iters=40000, val_interval=10000)
val_cfg = dict(type="ValLoop")
test_cfg = dict(type="TestLoop")
default_hooks = dict(
timer=dict(type="IterTimerHook"),
logger=dict(type="LoggerHook", interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type="ParamSchedulerHook"),
checkpoint=dict(
type="CheckpointHook", by_epoch=False, interval=4000, max_keep_ckpts=3
),
sampler_seed=dict(type="DistSamplerSeedHook"),
visualization=dict(type="SegVisualizationHook"),
)
find_unused_parameters = True