Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CoCa: fix MultimodalTransformer init + Mask CLS token at end of seq #551

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/open_clip/coca_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,15 @@ def encode_text(self, text, normalize=True, embed_cls=True):
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
return text_latent

def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
def forward(self, image, text=None, embed_cls=True, image_latent=None, image_embs=None):
if image_latent is None or image_embs is None:
image_latent, image_embs = self._encode_image(image)

if text is None:
return {"image_features": image_latent, "image_embs": image_embs}

text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)

# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]

Expand Down
35 changes: 20 additions & 15 deletions src/open_clip/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def build_attention_mask(self):

def build_cls_mask(self, text, cast_dtype: torch.dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
cls_mask = F.pad(cls_mask, (0, 1, cls_mask.shape[2], 0), value=1.0)
additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
Expand Down Expand Up @@ -673,23 +673,28 @@ def __init__(
self.ln_final = norm_layer(width)
self.text_projection = nn.Parameter(torch.empty(width, output_dim))

self.init_parameters()

def init_parameters(self):
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
for block in self.transformer.cross_attn:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
# proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
# attn_std = self.width ** -0.5
# fc_std = (2 * self.width) ** -0.5
# for block in self.resblocks:
# nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
# nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
# nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
# nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
# for block in self.cross_attn:
# nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
# nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
# nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
# nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
# nn.init.normal_(self.text_projection, std=self.width ** -0.5)
nn.init.zeros_(self.text_projection)
# nn.init.kaiming_uniform_(self.text_projection, a=math.sqrt(5)) # nn.Linear default


def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the tokens
Expand Down
18 changes: 9 additions & 9 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
'ViT-e-14',
'mt5-xl-ViT-H-14',
'coca_base',
'coca_ViT-B-32',
# 'coca_ViT-B-32',
'coca_roberta-ViT-B-32'
})

Expand Down Expand Up @@ -60,14 +60,6 @@ def test_inference_with_data(
force_quick_gelu = False,
):
util_test.seed_all()
model, _, preprocess_val = open_clip.create_model_and_transforms(
model_name,
pretrained = pretrained,
precision = precision,
jit = jit,
force_quick_gelu = force_quick_gelu,
pretrained_hf = pretrained_hf
)
model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}'
input_dir, output_dir = util_test.get_data_dirs()
# text
Expand All @@ -77,6 +69,14 @@ def test_inference_with_data(
pytest.skip(reason = f"missing test data, expected at {input_text_path}")
if not os.path.isfile(gt_text_path):
pytest.skip(reason = f"missing test data, expected at {gt_text_path}")
model, _, preprocess_val = open_clip.create_model_and_transforms(
model_name,
pretrained = pretrained,
precision = precision,
jit = jit,
force_quick_gelu = force_quick_gelu,
pretrained_hf = pretrained_hf
)
input_text = torch.load(input_text_path)
gt_text = torch.load(gt_text_path)
y_text = util_test.inference_text(model, model_name, input_text)
Expand Down