beit support non-square input_shape
leondgarse committed May 10, 2022
1 parent 3442b1f commit fa5acb1
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions keras_cv_attention_models/beit/beit.py
@@ -39,7 +39,7 @@ def __init__(self, with_cls_token=True, attn_height=-1, num_heads=-1, **kwargs):
     def build(self, attn_shape):
         # print(attn_shape)
         if self.attn_height == -1:
-            height = width = int(tf.math.sqrt(float(attn_shape[2] - self.cls_token_len)))  # assume hh == ww, e.g. 14
+            height = width = int(tf.math.sqrt(float(attn_shape[2] - self.cls_token_len)))  # hh == ww, e.g. 14
         else:
             height = self.attn_height
             width = int(float(attn_shape[2] - self.cls_token_len) / height)
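For intuition, here is the shape arithmetic the new branch performs, as a standalone sketch. The 224x384 input, patch size 16, and single cls token are illustrative values, not taken from the diff:

    import math

    cls_token_len = 1
    height, width = 224 // 16, 384 // 16       # 14 x 24 patch grid
    attn_len = cls_token_len + height * width  # 337, i.e. attn_shape[2]

    # Old square-only assumption: int(sqrt(336)) == 18, not a valid edge of the 14 x 24 grid.
    square_guess = int(math.sqrt(attn_len - cls_token_len))

    # New branch: with attn_height known, the width is recovered exactly.
    attn_height = height
    recovered_width = int((attn_len - cls_token_len) / attn_height)
    print(square_guess, recovered_width)       # 18 24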
@@ -121,7 +121,7 @@ def show_pos_emb(self, rows=1, base_size=2):
         return fig


-def attention_block(inputs, num_heads=4, key_dim=0, out_weight=True, out_bias=False, qv_bias=True, attn_dropout=0, name=None):
+def attention_block(inputs, num_heads=4, key_dim=0, out_weight=True, out_bias=False, qv_bias=True, attn_height=-1, attn_dropout=0, name=None):
     _, bb, cc = inputs.shape
     key_dim = key_dim if key_dim > 0 else cc // num_heads
     qk_scale = float(1.0 / tf.math.sqrt(tf.cast(key_dim, "float32")))
@@ -146,7 +146,7 @@ def attention_block(inputs, num_heads=4, key_dim=0, out_weight=True, out_bias=False, qv_bias=True, attn_dropout=0, name=None):
     query *= qk_scale
     # [batch, num_heads, cls_token + hh * ww, cls_token + hh * ww]
     attention_scores = keras.layers.Lambda(lambda xx: tf.matmul(xx[0], xx[1]))([query, key])
-    attention_scores = MultiHeadRelativePositionalEmbedding(name=name and name + "pos_emb")(attention_scores)
+    attention_scores = MultiHeadRelativePositionalEmbedding(attn_height=attn_height, name=name and name + "pos_emb")(attention_scores)
     # attention_scores = tf.nn.softmax(attention_scores, axis=-1, name=name and name + "_attention_scores")
     attention_scores = keras.layers.Softmax(axis=-1, name=name and name + "attention_scores")(attention_scores)

@@ -225,6 +225,7 @@ def Beit(

     """ forward_embeddings """
     nn = conv2d_no_bias(inputs, embed_dim, patch_size, strides=patch_size, padding="valid", use_bias=True, name="stem_")
+    patch_height = nn.shape[1]
     nn = keras.layers.Reshape([-1, nn.shape[-1]])(nn)
     nn = ClassToken(name="cls_token")(nn)

@@ -234,6 +235,7 @@
         "qv_bias": attn_qv_bias,
         "out_weight": attn_out_weight,
         "out_bias": attn_out_bias,
+        "attn_height": patch_height,
         "attn_dropout": attn_dropout,
     }

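With patch_height wired into the attention kwargs, a rectangular input_shape should build end to end. A hedged usage sketch, assuming the package's usual model constructors and keyword arguments (BeitBasePatch16 and pretrained are not part of this diff); pretrained=None skips weight loading so no positional-embedding reshaping is involved:

    from keras_cv_attention_models import beit

    # 160 x 256 input -> a 10 x 16 patch grid with patch_size 16 (assumed example values)
    mm = beit.BeitBasePatch16(input_shape=(160, 256, 3), pretrained=None)
    print(mm.output_shape)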
