diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py
index 039453af7..7d6d4ecb6 100644
--- a/src/open_clip/coca_model.py
+++ b/src/open_clip/coca_model.py
@@ -147,11 +147,15 @@ def encode_text(self, text, normalize=True, embed_cls=True):
         text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
         return text_latent
 
-    def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
-        text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
+    def forward(self, image, text=None, embed_cls=True, image_latent=None, image_embs=None):
         if image_latent is None or image_embs is None:
             image_latent, image_embs = self._encode_image(image)
 
+        if text is None:
+            return {"image_features": image_latent, "image_embs": image_embs}
+
+        text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
+
         # TODO: add assertion to avoid bugs?
         labels = text[:, -token_embs.shape[1]:]
 
diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index 0a30e9466..610e8b952 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -586,7 +586,7 @@ def build_attention_mask(self):
 
     def build_cls_mask(self, text, cast_dtype: torch.dtype):
         cls_mask = (text != self.pad_id).unsqueeze(1)
-        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
+        cls_mask = F.pad(cls_mask, (0, 1, cls_mask.shape[2], 0), value=1.0)
         additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
         additive_mask.fill_(0)
         additive_mask.masked_fill_(~cls_mask, float("-inf"))
@@ -673,23 +673,28 @@ def __init__(
         self.ln_final = norm_layer(width)
         self.text_projection = nn.Parameter(torch.empty(width, output_dim))
 
+        self.init_parameters()
+
     def init_parameters(self):
-        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
-        attn_std = self.transformer.width ** -0.5
-        fc_std = (2 * self.transformer.width) ** -0.5
-        for block in self.transformer.resblocks:
-            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
-            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
-            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
-            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
-        for block in self.transformer.cross_attn:
-            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
-            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
-            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
-            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+        # proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
+        # attn_std = self.width ** -0.5
+        # fc_std = (2 * self.width) ** -0.5
+        # for block in self.resblocks:
+        #     nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+        #     nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+        #     nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+        #     nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+        # for block in self.cross_attn:
+        #     nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+        #     nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+        #     nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+        #     nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
 
         if self.text_projection is not None:
-            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+            # nn.init.normal_(self.text_projection, std=self.width ** -0.5)
+            nn.init.zeros_(self.text_projection)
+            # nn.init.kaiming_uniform_(self.text_projection, a=math.sqrt(5)) # nn.Linear default
+
 
     def build_attention_mask(self):
         # lazily create causal attention mask, with full attention between the tokens
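
Not part of the diff itself, but a quick way to see what the build_cls_mask change does: F.pad takes pad amounts for the last dimension first, as (left, right, ...), so (1, 0, ...) prepends the extra key slot while (0, 1, ...) appends it. Since the text tower appends the CLS embedding at the end of the sequence, the always-attendable slot has to be the last column, not the first. A minimal sketch; the toy token ids and pad_id below are assumptions for illustration only:

import torch
import torch.nn.functional as F

pad_id = 0                                   # assumed padding token id for the demo
text = torch.tensor([[5, 7, 9, 0]])          # one sequence with a trailing padding token
cls_mask = (text != pad_id).unsqueeze(1)     # [batch, 1, seq_len] key-padding mask

old = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)  # buggy: slot prepended
new = F.pad(cls_mask, (0, 1, cls_mask.shape[2], 0), value=1.0)  # fixed: slot appended

# Last row is the mask seen by the appended CLS query; last column is the CLS key.
print(old[0, -1])  # tensor([True, True, True, True, False]) -> CLS blocked from its own slot,
                   #                                            padding key left attendable
print(new[0, -1])  # tensor([True, True, True, False, True]) -> padding masked, CLS slot open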
diff --git a/tests/test_inference.py b/tests/test_inference.py
index dca8dc44c..63dcfddb5 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -28,7 +28,7 @@
     'ViT-e-14',
     'mt5-xl-ViT-H-14',
     'coca_base',
-    'coca_ViT-B-32',
+    # 'coca_ViT-B-32',
     'coca_roberta-ViT-B-32'
 })
 
@@ -60,14 +60,6 @@ def test_inference_with_data(
         force_quick_gelu = False,
 ):
     util_test.seed_all()
-    model, _, preprocess_val = open_clip.create_model_and_transforms(
-        model_name,
-        pretrained = pretrained,
-        precision = precision,
-        jit = jit,
-        force_quick_gelu = force_quick_gelu,
-        pretrained_hf = pretrained_hf
-    )
     model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}'
     input_dir, output_dir = util_test.get_data_dirs()
     # text
@@ -77,6 +69,14 @@ def test_inference_with_data(
         pytest.skip(reason = f"missing test data, expected at {input_text_path}")
     if not os.path.isfile(gt_text_path):
         pytest.skip(reason = f"missing test data, expected at {gt_text_path}")
+    model, _, preprocess_val = open_clip.create_model_and_transforms(
+        model_name,
+        pretrained = pretrained,
+        precision = precision,
+        jit = jit,
+        force_quick_gelu = force_quick_gelu,
+        pretrained_hf = pretrained_hf
+    )
    input_text = torch.load(input_text_path)
    gt_text = torch.load(gt_text_path)
    y_text = util_test.inference_text(model, model_name, input_text)
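
For context, a hedged usage sketch of the modified CoCa forward(): with the new text=None default, calling the model on images alone returns the image features and per-token image embeddings without touching the text tower. The model name comes from the config list in the test file above; the random input batch is an illustrative assumption, not part of the diff.

import torch
import open_clip

# Build an untrained CoCa model from one of the configs referenced in the tests.
model, _, preprocess = open_clip.create_model_and_transforms('coca_ViT-B-32')
model.eval()

images = torch.randn(2, 3, 224, 224)  # stand-in for a preprocessed image batch

with torch.no_grad():
    out = model(images)                   # text=None -> image-only path added above
    image_latent = out["image_features"]  # pooled embedding used for the contrastive loss
    image_embs = out["image_embs"]        # per-token embeddings consumed by the text decoder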