Skip to content

Commit

Permalink
fix swinv2 for TPU saving model
Browse files Browse the repository at this point in the history
  • Loading branch information
leondgarse committed Apr 25, 2022
1 parent c6d10a0 commit ac522ab
Showing 1 changed file with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def window_multi_head_self_attention(inputs, filters=-1, num_heads=4, meta_hidde
if mask is not None:
query_blocks = attn.shape[2]
attn = tf.reshape(attn, [-1, mask.shape[0], num_heads, query_blocks, query_blocks])
attn += tf.expand_dims(tf.expand_dims(mask, 1), 0) # expand dims on batch and num_heads
# attn += tf.expand_dims(tf.expand_dims(mask, 1), 0) # expand dims on batch and num_heads
mask = tf.expand_dims(tf.expand_dims(mask, 1), 0) # expand dims on batch and num_heads
attn = keras.layers.Add()([attn, mask])
attn = tf.reshape(attn, [-1, num_heads, query_blocks, query_blocks])
attention_scores = keras.layers.Softmax(axis=-1, name=name and name + "attention_scores")(attn)

Expand All @@ -107,7 +109,7 @@ def window_multi_head_self_attention(inputs, filters=-1, num_heads=4, meta_hidde


def make_window_attention_mask(height, width, window_height, window_width, shift_height, shift_width):
float_dtype = tf.keras.mixed_precision.global_policy().compute_dtype
# float_dtype = tf.keras.mixed_precision.global_policy().compute_dtype
hh_split = [0, height - window_height, height - shift_height, height]
ww_split = [0, width - window_width, width - shift_width, width]
mask_value, total_ww, mask = 0, len(ww_split) - 1, []
Expand All @@ -123,7 +125,7 @@ def make_window_attention_mask(height, width, window_height, window_width, shift
mask = tf.transpose(mask, [0, 2, 1, 3])
mask = tf.reshape(mask, [-1, window_height * window_width])
attn_mask = tf.expand_dims(mask, 1) - tf.expand_dims(mask, 2)
return tf.cast(tf.where(attn_mask != 0, -100, 0), float_dtype)
return tf.cast(tf.where(attn_mask != 0, -100, 0), "float32")


def shifted_window_attention(inputs, window_size, num_heads=4, shift_size=0, name=""):
Expand Down

1 comment on commit ac522ab

@leondgarse
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For #54

Please sign in to comment.