temp adding fastervit
leondgarse committed Jun 17, 2023
1 parent 1c35306 commit 3fe2728
Showing 6 changed files with 298 additions and 30 deletions.
9 changes: 8 additions & 1 deletion keras_cv_attention_models/attention_layers/__init__.py
@@ -49,7 +49,14 @@
from keras_cv_attention_models.coatnet.coatnet import mhsa_with_multi_head_relative_position_embedding
from keras_cv_attention_models.cmt.cmt import light_mhsa_with_multi_head_relative_position_embedding, BiasPositionalEmbedding
from keras_cv_attention_models.uniformer.uniformer import multi_head_self_attention
from keras_cv_attention_models.davit.davit import multi_head_self_attention_channel, window_attention
from keras_cv_attention_models.davit.davit import (
multi_head_self_attention_channel,
window_attention,
window_partition,
window_reverse,
pad_to_divisible_by_window_size,
reverse_padded_for_window_size,
)
from keras_cv_attention_models.edgenext.edgenext import PositionalEncodingFourier, cross_covariance_attention
from keras_cv_attention_models.efficientvit.efficientvit_m import cascaded_mhsa_with_multi_head_position
from keras_cv_attention_models.mobilevit.mobilevit import linear_self_attention
67 changes: 45 additions & 22 deletions keras_cv_attention_models/davit/davit.py
@@ -60,9 +60,32 @@ def multi_head_self_attention_channel(
return attention_output


def __window_partition__(inputs, patch_height, patch_width, window_height, window_width):
def pad_to_divisible_by_window_size(inputs, window_size):
window_size = window_size if isinstance(window_size, (list, tuple)) else [window_size, window_size]
window_height = window_size[0] if window_size[0] < inputs.shape[1] else inputs.shape[1]
window_width = window_size[1] if window_size[1] < inputs.shape[2] else inputs.shape[2]

# window_partition, partition windows, ceil mode
patch_height, patch_width = int(math.ceil(inputs.shape[1] / window_height)), int(math.ceil(inputs.shape[2] / window_width))
padding_height, padding_width = patch_height * window_height - inputs.shape[1], patch_width * window_width - inputs.shape[2]
# print(f">>>> window_attention {inputs.shape = }, {padding_height = }, {padding_width = }")
if padding_height or padding_width:
inputs = functional.pad(inputs, [[0, 0], [0, padding_height], [0, padding_width], [0, 0]])
return inputs, window_height, window_width, padding_height, padding_width


def reverse_padded_for_window_size(inputs, padding_height, padding_width):
if padding_height or padding_width:
inputs = inputs[:, : inputs.shape[1] - padding_height, : inputs.shape[2] - padding_width, :] # In case padding_height or padding_width is 0
return inputs


def window_partition(inputs, window_height, window_width=-1):
"""[B, patch_height * window_height, patch_width * window_width, channel] -> [B * patch_height * patch_width, window_height, window_width, channel]"""
input_channel = inputs.shape[-1]
window_width = window_width if window_width > 0 else window_height
patch_height, patch_width = inputs.shape[1] // window_height, inputs.shape[2] // window_width

# print(f">>>> window_attention {inputs.shape = }, {patch_height = }, {patch_width = }, {window_height = }, {window_width = }")
# [batch * patch_height, window_height, patch_width, window_width * channel], limit transpose perm <= 4
nn = functional.reshape(inputs, [-1, window_height, patch_width, window_width * input_channel])
@@ -71,18 +94,25 @@ def __window_partition__(inputs, patch_height, patch_width, window_height, windo
return nn


def __window_reverse__(inputs, patch_height, patch_width, window_height, window_width):
def window_reverse(inputs, patch_height, patch_width=-1):
"""[B * patch_height * patch_width, window_height, window_width, channel] -> [B, patch_height * window_height, patch_width * window_width, channel]"""
input_channel = inputs.shape[-1]
patch_width = patch_width if patch_width > 0 else patch_height
window_height, window_width = inputs.shape[1], inputs.shape[2]

# [batch * patch_height, patch_width, window_height, window_width * input_channel], limit transpose perm <= 4
nn = functional.reshape(inputs, [-1, patch_width, window_height, window_width * input_channel])
nn = functional.transpose(nn, [0, 2, 1, 3]) # [batch * patch_height, window_height, patch_width, window_width * input_channel]
nn = functional.reshape(nn, [-1, patch_height * window_height, patch_width * window_width, input_channel])
return nn


def __grid_window_partition__(inputs, patch_height, patch_width, window_height, window_width):
def grid_window_partition(inputs, window_height, window_width=-1):
"""[B, window_height * patch_height, window_width * patch_width , channel] -> [B * patch_height * patch_width, window_height, window_width, channel]"""
input_channel = inputs.shape[-1]
window_width = window_width if window_width > 0 else window_height
patch_height, patch_width = inputs.shape[1] // window_height, inputs.shape[2] // window_width

nn = functional.reshape(inputs, [-1, window_height, patch_height, window_width * patch_width * input_channel])
nn = functional.transpose(nn, [0, 2, 1, 3]) # [batch, patch_height, window_height, window_width * patch_width * input_channel]
nn = functional.reshape(nn, [-1, window_height * window_width, patch_width, input_channel])
@@ -91,8 +121,12 @@ def __grid_window_partition__(inputs, patch_height, patch_width, window_height,
return nn


def __grid_window_reverse__(inputs, patch_height, patch_width, window_height, window_width):
def grid_window_reverse(inputs, patch_height, patch_width=-1):
"""[B * patch_height * patch_width, window_height, window_width, channel] -> [B, window_height * patch_height, window_width * patch_width , channel]"""
input_channel = inputs.shape[-1]
patch_width = patch_width if patch_width > 0 else patch_height
window_height, window_width = inputs.shape[1], inputs.shape[2]

nn = functional.reshape(inputs, [-1, patch_width, window_height * window_width, input_channel])
nn = functional.transpose(nn, [0, 2, 1, 3]) # [batch * patch_height, window_height * window_width, patch_width, input_channel]
nn = functional.reshape(nn, [-1, patch_height, window_height, window_width * patch_width * input_channel])
@@ -103,22 +137,13 @@ def __grid_window_reverse__(inputs, patch_height, patch_width, window_height, wi

def window_attention(inputs, window_size, num_heads=4, is_grid=False, attention_block=None, data_format="channels_last", name=None, **kwargs):
inputs = inputs if data_format == "channels_last" else functional.transpose(inputs, [0, 2, 3, 1])

window_size = window_size if isinstance(window_size, (list, tuple)) else [window_size, window_size]
window_height = window_size[0] if window_size[0] < inputs.shape[1] else inputs.shape[1]
window_width = window_size[1] if window_size[1] < inputs.shape[2] else inputs.shape[2]

# window_partition, partition windows, ceil mode
patch_height, patch_width = int(math.ceil(inputs.shape[1] / window_height)), int(math.ceil(inputs.shape[2] / window_width))
should_pad_hh, should_pad_ww = patch_height * window_height - inputs.shape[1], patch_width * window_width - inputs.shape[2]
# print(f">>>> window_attention {inputs.shape = }, {should_pad_hh = }, {should_pad_ww = }")
if should_pad_hh or should_pad_ww:
inputs = functional.pad(inputs, [[0, 0], [0, should_pad_hh], [0, should_pad_ww], [0, 0]])
inputs, window_height, window_width, padding_height, padding_width = pad_to_divisible_by_window_size(inputs, window_size)
patch_height, patch_width = inputs.shape[1] // window_height, inputs.shape[2] // window_width

if is_grid:
nn = __grid_window_partition__(inputs, patch_height, patch_width, window_height, window_width)
nn = grid_window_partition(inputs, window_height, window_width)
else:
nn = __window_partition__(inputs, patch_height, patch_width, window_height, window_width)
nn = window_partition(inputs, window_height, window_width)

if attention_block:
nn = nn if data_format == "channels_last" else functional.transpose(nn, [0, 3, 1, 2])
@@ -129,13 +154,11 @@ def window_attention(inputs, window_size, num_heads=4, is_grid=False, attention_

# window_reverse, merge windows
if is_grid:
nn = __grid_window_reverse__(nn, patch_height, patch_width, window_height, window_width)
nn = grid_window_reverse(nn, patch_height, patch_width)
else:
nn = __window_reverse__(nn, patch_height, patch_width, window_height, window_width)

if should_pad_hh or should_pad_ww:
nn = nn[:, : nn.shape[1] - should_pad_hh, : nn.shape[2] - should_pad_ww, :] # In case should_pad_hh or should_pad_ww is 0
nn = window_reverse(nn, patch_height, patch_width)

nn = reverse_padded_for_window_size(nn, padding_height, padding_width)
return nn if data_format == "channels_last" else functional.transpose(nn, [0, 3, 1, 2])


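Note: the refactor above extracts the padding and window bookkeeping from window_attention into standalone helpers, which the __init__.py change also re-exports from attention_layers. A minimal round-trip sketch of how they compose, assuming the default TensorFlow backend; the input shape and window size below are illustrative, not part of the commit:

import tensorflow as tf
from keras_cv_attention_models.attention_layers import (
    pad_to_divisible_by_window_size, reverse_padded_for_window_size, window_partition, window_reverse
)

inputs = tf.zeros([1, 14, 14, 32])  # [batch, height, width, channel], spatial dims not divisible by the window size
padded, window_height, window_width, pad_h, pad_w = pad_to_divisible_by_window_size(inputs, window_size=4)
# padded: [1, 16, 16, 32], pad_h == pad_w == 2

windows = window_partition(padded, window_height, window_width)
# windows: [batch * patch_height * patch_width, window_height, window_width, channel] == [16, 4, 4, 32]

merged = window_reverse(windows, padded.shape[1] // window_height, padded.shape[2] // window_width)
outputs = reverse_padded_for_window_size(merged, pad_h, pad_w)
# outputs: [1, 14, 14, 32], the original spatial shape is recovered
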
5 changes: 0 additions & 5 deletions keras_cv_attention_models/edgenext/edgenext.py
@@ -4,16 +4,11 @@
from keras_cv_attention_models.models import register_model
from keras_cv_attention_models.attention_layers import (
ChannelAffine,
activation_by_name,
conv2d_no_bias,
depthwise_conv2d_no_bias,
drop_block,
layer_norm,
mlp_block,
multi_head_self_attention,
output_block,
qkv_to_multi_head_channels_last_format,
scaled_dot_product_attention,
add_pre_post_process,
)
from keras_cv_attention_models.download_and_load import reload_model_weights
