Commit

add use_shared_pos_emb_for_attn parameter for beit supporting raw model w/o any finetuning

leondgarse committed May 29, 2024
1 parent dccf3e7 commit 14a9c07
Showing 3 changed files with 16 additions and 1 deletion.
2 changes: 2 additions & 0 deletions keras_cv_attention_models/beit/__init__.py
@@ -110,6 +110,8 @@
False for `FlexiViT`, same as `no_embed_class` in timm. Default True for others.
use_rot_pos_emb: boolean value if use `PositionalEncodingFourierRot` on attention query and key.
True for EVA02, False for others.
+ use_shared_pos_emb_for_attn: boolean value if use a shared `MultiHeadRelativePositionalEmbedding` layer for all attention blocks.
+ True for Beit raw model without any finetuning, False for others.
[MLP args]
mlp_ratio: dimension expansion ratio for `mlp_block`s. Default `4`.
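For context, a minimal usage sketch of the new flag follows. The model name `BeitBasePatch16`, `pretrained=None`, and the kwarg forwarding into `Beit` are assumptions based on the package's usual constructor pattern, not part of this commit:

```python
# Hypothetical usage sketch: build a BEiT backbone whose attention blocks all
# reuse one `MultiHeadRelativePositionalEmbedding` bias table, matching the
# layout of the raw (not finetuned) BEiT checkpoints.
from keras_cv_attention_models import beit

# Only `use_shared_pos_emb_for_attn` is introduced by this commit; the other
# arguments are assumed defaults of the package's constructors.
model = beit.BeitBasePatch16(pretrained=None, use_shared_pos_emb_for_attn=True)
model.summary()
```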
13 changes: 13 additions & 0 deletions keras_cv_attention_models/beit/beit.py
@@ -303,6 +303,7 @@ def attention_block(
use_pos_emb=False,
use_rot_pos_emb=False,
qk_rope=None,
+ shared_pos_emb=None,
attn_height=-1,
text_max_block_size=0,  # Also a mark if this is a text input
attn_dropout=0,
@@ -344,6 +345,8 @@

if is_text_inputs:
pos_emb = CausalMask(block_size=text_max_block_size)
+ elif shared_pos_emb:
+     pos_emb = shared_pos_emb
elif use_pos_emb:
pos_emb = MultiHeadRelativePositionalEmbedding(attn_height=attn_height, name=name and name + "pos_emb")
else:
@@ -482,6 +485,7 @@ def Beit(
use_abs_pos_emb=False, # [Pos emb args] True for Vit, False for Beit, whether to use absolute positional embedding or relative one in attention blocks
use_abs_pos_emb_on_cls_token=True, # [Pos emb args] False for FlexiViT, no_embed_class in timm. If use_abs_pos_emb is True, whether to apply pos_emb on cls_token.
use_rot_pos_emb=False, # [Pos emb args] True for EVA02, False for others
+ use_shared_pos_emb_for_attn=False, # [Pos emb args] True for beit raw model without any finetuning
mlp_ratio=4, # [MLP args]
use_gated_mlp=False, # [MLP args] True for DINOv2 and EVA02
use_norm_mlp=False, # [MLP args] True for EVA02 base and large, False for others.
@@ -568,6 +572,7 @@ def Beit(
"use_rot_pos_emb": use_rot_pos_emb,
"text_max_block_size": max_block_size if vocab_size > 0 else 0,
"attn_dropout": attn_dropout,
"shared_pos_emb": MultiHeadRelativePositionalEmbedding(attn_height=patch_height, name="shared_pos_emb") if use_shared_pos_emb_for_attn else None,
}

drop_connect_rates = drop_connect_rates_split([depth], 0.0, drop_connect_rate)[0]
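The hunk above builds a single `MultiHeadRelativePositionalEmbedding` instance and hands it to every attention block; in Keras, reusing one layer object across calls is what ties the weights together. A toy sketch of that sharing pattern, using a stand-in `Dense` layer rather than the real positional-embedding layer:

```python
# Illustrative only: calling one Keras layer instance from several places reuses
# the same variables. Passing a single shared layer into every attention block
# therefore shares one relative position bias table across the whole model.
from keras import Input, layers, models

shared_layer = layers.Dense(4, use_bias=False, name="shared_toy_layer")  # stand-in for shared_pos_emb

inputs = Input([7, 4])
block_1 = shared_layer(inputs)   # first "block" applies the layer
block_2 = shared_layer(block_1)  # second "block" reuses the exact same kernel
model = models.Model(inputs, block_2)
print(len(model.trainable_weights))  # -> 1, a single kernel shared by both calls
```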
@@ -663,6 +668,14 @@ def keras_model_load_weights_from_pytorch_model(keras_model, timm_vit_model, sav
full_name_align_dict = {"cls_token": -2, "positional_embedding": -1}
else:
full_name_align_dict = {"cls_token": -1, "positional_embedding": -1}

if "shared_pos_emb" in [ii.name for ii in keras_model.layers]:
full_name_align_dict["shared_pos_emb"] = -4
if "attn_gamma" in tail_align_dict:
tail_align_dict["attn_gamma"] += 1
if "mlp_gamma" in tail_align_dict:
tail_align_dict["mlp_gamma"] += 1

additional_transfer = {attention_layers.MultiHeadRelativePositionalEmbedding: lambda ww: [ww[0].T]}

download_and_load.keras_reload_from_torch_model(
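A hedged porting sketch showing where the adjusted alignment dict would come into play. The timm model id, the import path of the helper, and leaving its remaining arguments at their defaults are all assumptions for illustration:

```python
# Hypothetical sketch of porting raw BEiT weights from a PyTorch/timm model into
# the Keras model built with the new flag; names below are placeholders.
import timm
from keras_cv_attention_models import beit
from keras_cv_attention_models.beit.beit import keras_model_load_weights_from_pytorch_model

# Assumed timm id; substitute the raw (not finetuned) BEiT checkpoint to port.
timm_model = timm.create_model("beit_base_patch16_224", pretrained=True)

keras_model = beit.BeitBasePatch16(pretrained=None, use_shared_pos_emb_for_attn=True)
keras_model_load_weights_from_pytorch_model(keras_model, timm_model)
```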
2 changes: 1 addition & 1 deletion keras_cv_attention_models/llama2/llama2.py
@@ -163,7 +163,7 @@ def causal_self_attention_with_cache(inputs, start_pos=0, max_batch_size=0, bloc

value = layers.Reshape([-1, num_kv_heads, key_dim], name=name + "value_reshape")(value)
if is_kv_cache: # Use KV cache if max_batch_size specified
- value = KVCache(max_batch_size=max_batch_size, max_seq_len=block_size, name=name + "value_cahce")([value, start_pos])
+ value = KVCache(max_batch_size=max_batch_size, max_seq_len=block_size, name=name + "value_cache")([value, start_pos])
value = functional.transpose(value, [0, 2, 1, 3])

if num_kv_heads != num_heads:

1 comment on commit 14a9c07

@leondgarse (Owner, Author)
