[Fix]: Set qkv bias to False for cae and True for mae (#303)
* [Fix]: Add mmcls transformer layer choice

* [Fix]: Fix transformer encoder layer bug

* [Fix]: Change UT of cae
YuanLiuuuuuu authored May 5, 2022
1 parent 399b5a0 commit 249b7db
Showing 4 changed files with 29 additions and 22 deletions.
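Background on the fix: CAE's BEiT-style attention assembles its qkv bias by hand, giving q and v learnable bias parameters while pinning k's bias to zero, so the fused nn.Linear bias must be disabled (qkv_bias=False); MAE uses the standard fused projection bias (qkv_bias=True). A self-contained sketch of this mechanism follows the transformer_blocks.py diff below.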
@@ -1,7 +1,7 @@
 _base_ = 'vit-base-p16_ft-8xb128-coslr-100e_in1k.py'
 
 # model
-model = dict(backbone=dict(use_window=True, init_values=0.1))
+model = dict(backbone=dict(use_window=True, init_values=0.1, qkv_bias=False))
 
 # optimizer
 optimizer = dict(lr=8e-3)
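For context, MMCV configs merge nested dicts rather than replace them, so the one-line override above adds qkv_bias=False to the base model's backbone while keeping use_window and init_values. A minimal sketch of that behavior, assuming mmcv 1.x; the file path is hypothetical, since this diff does not show the config's filename:

from mmcv import Config

# Hypothetical path: the name of the config changed above is not
# visible in this diff.
cfg = Config.fromfile('configs/benchmarks/classification/cae_ft.py')

# The child config only names qkv_bias; the other backbone keys are
# merged in from _base_ instead of being overwritten.
assert cfg.model.backbone.qkv_bias is False
assert cfg.model.backbone.use_window is True
assert cfg.optimizer.lr == 8e-3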
configs/selfsup/_base_/models/cae.py (6 additions & 1 deletion)
@@ -1,7 +1,12 @@
 # model settings
 model = dict(
     type='CAE',
-    backbone=dict(type='CAEViT', arch='b', patch_size=16, init_values=0.1),
+    backbone=dict(
+        type='CAEViT',
+        arch='b',
+        patch_size=16,
+        init_values=0.1,
+        qkv_bias=False),
     neck=dict(
         type='CAENeck',
         patch_size=16,
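To confirm the flag takes effect, the backbone can be built from this config fragment and checked for the absence of a fused qkv bias. A minimal sketch, assuming mmselfsup v0.x exposes build_backbone from mmselfsup.models (the registry builder); walking named_modules() avoids assuming the exact block layout:

import torch.nn as nn

from mmselfsup.models import build_backbone  # assumed v0.x builder API

backbone = build_backbone(
    dict(
        type='CAEViT',
        arch='b',
        patch_size=16,
        init_values=0.1,
        qkv_bias=False))

# With qkv_bias=False, every fused qkv projection is created without a
# bias; the layers carry their own q/v bias parameters instead (see the
# transformer_blocks.py diff next).
for name, module in backbone.named_modules():
    if isinstance(module, nn.Linear) and name.endswith('qkv'):
        assert module.bias is None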
mmselfsup/models/utils/transformer_blocks.py (20 additions & 19 deletions)
@@ -61,29 +61,30 @@ def __init__(self,
             proj_bias=proj_bias,
             init_cfg=init_cfg)
 
         del self.out_drop
-        self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=False)
-        if qkv_bias:
-            self.q_bias = nn.Parameter(torch.zeros(embed_dims))
-            self.v_bias = nn.Parameter(torch.zeros(embed_dims))
-        else:
-            self.q_bias = None
-            self.k_bias = None
-            self.v_bias = None
+        self.qkv_bias = qkv_bias
+
+        if not self.qkv_bias:
+            self._init_qv_bias()
+
+        self.qkv = nn.Linear(
+            self.input_dims, embed_dims * 3, bias=self.qkv_bias)
+
+    def _init_qv_bias(self):
+        self.q_bias = nn.Parameter(torch.zeros(self.embed_dims))
+        self.v_bias = nn.Parameter(torch.zeros(self.embed_dims))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # qkv bias is different from that in mmcls
-        qkv_bias = None
-        if self.q_bias is not None:
-            qkv_bias = torch.cat(
-                (self.q_bias,
-                 torch.zeros_like(self.v_bias,
-                                  requires_grad=False), self.v_bias))
         B, N, _ = x.shape
-        qkv = F.linear(
-            x, weight=self.qkv.weight,
-            bias=qkv_bias).reshape(B, N, 3, self.num_heads,
-                                   self.head_dims).permute(2, 0, 3, 1, 4)
+
+        if not self.qkv_bias:
+            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
+            qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
+            qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
+        else:
+            qkv = self.qkv(x)
+
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0], qkv[1], qkv[2]
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
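The diff above is the heart of the fix: when qkv_bias=False (the CAE case), the layer owns learnable q_bias and v_bias and concatenates them with a frozen zero k_bias at forward time; when qkv_bias=True (the MAE case), the plain fused projection is used. A self-contained PyTorch sketch of the same idea, with illustrative names and shapes rather than the mmselfsup class itself:

import torch
import torch.nn as nn
import torch.nn.functional as F


class DecomposedQKVBiasAttention(nn.Module):
    """BEiT/CAE-style attention: learnable q/v biases, k bias fixed at zero."""

    def __init__(self, embed_dims: int, num_heads: int, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (embed_dims // num_heads)**-0.5
        self.qkv_bias = qkv_bias
        # A fused bias exists only in the qkv_bias=True (MAE) case.
        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dims, embed_dims)
        if not qkv_bias:
            # The CAE case: q and v get their own trainable biases.
            self.q_bias = nn.Parameter(torch.zeros(embed_dims))
            self.v_bias = nn.Parameter(torch.zeros(embed_dims))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, _ = x.shape
        if not self.qkv_bias:
            # Assemble (q_bias, zeros, v_bias): k's bias never trains.
            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            bias = torch.cat((self.q_bias, k_bias, self.v_bias))
            qkv = F.linear(x, weight=self.qkv.weight, bias=bias)
        else:
            qkv = self.qkv(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        return self.proj((attn @ v).transpose(1, 2).reshape(B, N, -1))


# Quick check of both code paths:
x = torch.rand(2, 197, 768)
cae_style = DecomposedQKVBiasAttention(768, 12, qkv_bias=False)
mae_style = DecomposedQKVBiasAttention(768, 12, qkv_bias=True)
assert cae_style(x).shape == mae_style(x).shape == (2, 197, 768)

Pinning k's bias at zero loses no expressiveness: a k bias adds the same constant to every attention logit of a given query, and softmax is invariant to such shifts, which is why BEiT-style models omit it.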
tests/test_models/test_algorithms/test_cae.py (2 additions & 1 deletion)
@@ -7,7 +7,8 @@
 from mmselfsup.models.algorithms import CAE
 
 # model settings
-backbone = dict(type='CAEViT', arch='b', patch_size=16, init_values=0.1)
+backbone = dict(
+    type='CAEViT', arch='b', patch_size=16, init_values=0.1, qkv_bias=False)
 neck = dict(
     type='CAENeck',
     patch_size=16,