forked from open-mmlab/mmpretrain
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cae_beit-base-p16_8xb256-amp-coslr-300e_in1k.py
115 lines (105 loc) · 3.06 KB
/
cae_beit-base-p16_8xb256-amp-coslr-300e_in1k.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
_base_ = '../_base_/default_runtime.py'
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
type='TwoNormDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
second_mean=[-31.875, -31.875, -31.875],
second_std=[318.75, 318.75, 318.75],
to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomResizedCropAndInterpolationWithTwoPic',
size=224,
second_size=112,
interpolation='bicubic',
second_interpolation='lanczos',
scale=(0.08, 1.0)),
dict(
type='BEiTMaskGenerator',
input_size=(14, 14),
num_masking_patches=75,
max_num_patches=None,
min_num_patches=16),
dict(type='PackInputs')
]
train_dataloader = dict(
batch_size=256,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/train.txt',
data_prefix=dict(img_path='train/'),
pipeline=train_pipeline))
# model settings
model = dict(
type='CAE',
backbone=dict(
type='CAEPretrainViT',
arch='b',
patch_size=16,
layer_scale_init_value=0.1,
bias='qv_bias'),
neck=dict(
type='CAENeck',
embed_dims=768,
num_heads=12,
regressor_depth=4,
decoder_depth=4,
mlp_ratio=4,
layer_scale_init_value=0.1,
),
head=dict(type='CAEHead', loss=dict(type='CAELoss', lambd=2)),
target_generator=dict(
type='DALL-E',
init_cfg=dict(
type='Pretrained',
checkpoint= # noqa: E251
'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/dalle_encoder.pth', # noqa: E501
)),
base_momentum=0.0)
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW', lr=1.5e-3, betas=(0.9, 0.999), weight_decay=0.05),
clip_grad=dict(max_norm=3.0),
paramwise_cfg=dict(
bias_decay_mult=0.0, norm_decay_mult=0.0, flat_decay_mult=0.0))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=290,
eta_min=1e-5,
by_epoch=True,
begin=10,
end=300,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=2048)