temp add hiera
leondgarse committed Jun 10, 2023
1 parent aa6d113 commit 1bd6fd6
Showing 6 changed files with 319 additions and 1 deletion.
1 change: 1 addition & 0 deletions keras_cv_attention_models/__init__.py
@@ -25,6 +25,7 @@
from keras_cv_attention_models import ghostnet as ghostnetv2 # Will be removed
from keras_cv_attention_models import gpt2
from keras_cv_attention_models import halonet
from keras_cv_attention_models import hiera
from keras_cv_attention_models import iformer
from keras_cv_attention_models import levit
from keras_cv_attention_models import mlp_family
3 changes: 2 additions & 1 deletion keras_cv_attention_models/beit/beit.py
@@ -17,6 +17,7 @@

LAYER_NORM_EPSILON = 1e-6


PRETRAINED_DICT = {
"beit_base_patch16": {"imagenet21k-ft1k": {224: "d7102337a13a3983f3b6470de77b5d5c", 384: "76353026477c60f8fdcbcc749fea17b3"}},
"beit_v2_base_patch16": {"imagenet21k-ft1k": {224: "d001dcb67cdda16bfdbb2873ab9b13c8"}},
@@ -242,7 +243,7 @@ def scaled_dot_product_attention(query, key, value, output_shape, pos_emb=None,
# output = layers.Lambda(lambda xx: functional.matmul(xx[0], xx[1]))([attention_scores, value])
attention_output = attention_scores @ value
output = functional.transpose(attention_output, perm=[0, 2, 1, 3]) # [batch, q_blocks, num_heads, key_dim * attn_ratio]
output = functional.reshape(output, [-1, *blocks, output.shape[2] * output.shape[3]]) # [batch, q_blocks, channel * attn_ratio]
output = functional.reshape(output, [-1, *blocks, np.prod(output.shape[1:]) // np.prod(blocks)]) # [batch, q_blocks, channel * attn_ratio]

if out_weight:
# [batch, hh, ww, num_heads * key_dim] * [num_heads * key_dim, out] --> [batch, hh, ww, out]
2 changes: 2 additions & 0 deletions keras_cv_attention_models/davit/davit.py
@@ -61,6 +61,7 @@ def multi_head_self_attention_channel(


def __window_partition__(inputs, patch_height, patch_width, window_height, window_width):
""" [B, patch_height * window_height, patch_width * window_width, channel] -> [B * patch_height * patch_width, window_height, window_width, channel] """
input_channel = inputs.shape[-1]
# print(f">>>> window_attention {inputs.shape = }, {patch_height = }, {patch_width = }, {window_height = }, {window_width = }")
# [batch * patch_height, window_height, patch_width, window_width * channel], limit transpose perm <= 4
@@ -80,6 +81,7 @@ def __window_reverse__(inputs, patch_height, patch_width, window_height, window_


def __grid_window_partition__(inputs, patch_height, patch_width, window_height, window_width):
""" [B, window_height * patch_height, window_width * patch_width , channel] -> [B * patch_height * patch_width, window_height, window_width, channel] """
input_channel = inputs.shape[-1]
nn = functional.reshape(inputs, [-1, window_height, patch_height, window_width * patch_width * input_channel])
nn = functional.transpose(nn, [0, 2, 1, 3]) # [batch, patch_height, window_height, window_width * patch_width * input_channel]
53 changes: 53 additions & 0 deletions keras_cv_attention_models/hiera/README.md
@@ -0,0 +1,53 @@
# ___Keras Hiera___
***

## Summary
- Keras implementation of [Github facebookresearch/hiera](https://github.com/facebookresearch/hiera). Paper [PDF 2306.00989 Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles](https://arxiv.org/pdf/2306.00989.pdf).
- Model weights ported from official publication.
***

## Models
| Model | Params | FLOPs | Input | Top1 Acc | Download |
| ------------- | ------- | ------- | ----- | -------- | -------- |
| HieraTiny | 27.91M | 4.93G | 224 | 82.8 | |
| HieraSmall | 35.01M | 6.44G | 224 | 83.8 | |
| HieraBase | 51.52M | 9.43G | 224 | 84.5 | |
| HieraBasePlus | 69.90M | 12.71G | 224 | 85.2 | |
| HieraLarge | 213.74M | 40.43G | 224 | 86.1 | |
| HieraHuge | 672.78M | 125.03G | 224 | 86.9 | |
## Usage
```py
from keras_cv_attention_models import hiera, test_images

# Will download and load pretrained imagenet weights.
mm = hiera.HieraBase()

# Run prediction
preds = mm(mm.preprocess_input(test_images.cat()))
print(mm.decode_predictions(preds))
# [('n02124075', 'Egyptian_cat', 0.8966972), ('n02123045', 'tabby', 0.0072582546), ...]
```
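Setting `num_classes=0` excludes the top classification layers, so the backbone can also be used as a feature extractor. A minimal sketch of that use, under the assumption that the default 224 `channels_last` configuration is kept; the printed shape is illustrative:
```py
from keras_cv_attention_models import hiera

# Exclude top layers and skip pretrained weights, keeping only the backbone
mm = hiera.HieraBase(num_classes=0, pretrained=None)
print(mm.output_shape)
# Expected something like (None, 7, 7, 768) for a 224 channels_last input
```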
## Verification with PyTorch version
```py
""" PyTorch torch_hiera """
sys.path.append('../hiera/')
sys.path.append('../pytorch-image-models/') # Needs timm
import torch
from hiera import hiera as torch_hiera

torch_model = torch_hiera.hiera_base_224()
ss = torch.load('hiera_base_224.pth', map_location=torch.device('cpu'))
torch_model.load_state_dict(ss['model_state'])
_ = torch_model.eval()

""" Keras HieraBase """
from keras_cv_attention_models import hiera
mm = hiera.HieraBase(classifier_activation="softmax")

""" Verification """
inputs = np.random.uniform(size=(1, *mm.input_shape[1:3], 3)).astype("float32")
torch_out = torch_model(torch.from_numpy(inputs).permute(0, 3, 1, 2)).detach().numpy()
keras_out = mm(inputs).numpy()
print(f"{np.allclose(torch_out, keras_out, atol=5e-2) = }")
# np.allclose(torch_out, keras_out, atol=5e-2) = True
```
55 changes: 55 additions & 0 deletions keras_cv_attention_models/hiera/__init__.py
@@ -0,0 +1,55 @@
from keras_cv_attention_models.hiera.hiera import Hiera, HieraTiny, HieraSmall, HieraBase, HieraBasePlus, HieraLarge, HieraHuge

__head_doc__ = """
Keras implementation of [Github facebookresearch/hiera](https://github.com/facebookresearch/hiera).
Paper [PDF 2306.00989 Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles](https://arxiv.org/pdf/2306.00989.pdf).
"""

__tail_doc__ = """ input_shape: it should have exactly 3 input channels, like `(224, 224, 3)`.
num_classes: number of classes to classify images into. Set `0` to exclude top layers.
activation: activation used in whole model, default `gelu`.
  drop_connect_rate: stochastic depth drop rate, from [Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382).
      Can be a value like `0.2`, meaning the drop probability changes linearly from `0 --> 0.2` for `top --> bottom` layers.
      A higher value means a higher probability of dropping the deep branch. Set `0` to disable (default).
  dropout: dropout rate if top layers are included.
classifier_activation: A `str` or callable. The activation function to use on the "top" layer if `num_classes > 0`.
Set `classifier_activation=None` to return the logits of the "top" layer.
  pretrained: one of None or "mae_in1k_ft1k". Currently only `HieraBase` weights are available.
Will try to download and load pre-trained model weights if not None.
Returns:
A `keras.Model` instance.
"""

Hiera.__doc__ = __head_doc__ + """
Args:
num_blocks: number of blocks in each stack.
  embed_dim: base hidden dim, doubled for each stack.
  num_heads: int or list value for the number of attention heads in each stack.
  use_window_attentions: boolean or list value, where each element can also be a list of booleans.
      Indicates whether window attention is used in each stack.
      An element like `[True, False]` means the first block uses window attention and the remaining blocks do not.
  mlp_ratio: expansion ratio for the MLP block hidden channels.
model_name: string, model name.
""" + __tail_doc__ + """
Model architectures:
| Model | Params | FLOPs | Input | Top1 Acc |
| ------------- | ------- | ------- | ----- | -------- |
| HieraTiny | 27.91M | 4.93G | 224 | 82.8 |
| HieraSmall | 35.01M | 6.44G | 224 | 83.8 |
| HieraBase | 51.52M | 9.43G | 224 | 84.5 |
| HieraBasePlus | 69.90M | 12.71G | 224 | 85.2 |
| HieraLarge | 213.74M | 40.43G | 224 | 86.1 |
| HieraHuge | 672.78M | 125.03G | 224 | 86.9 |
"""

HieraTiny.__doc__ = __head_doc__ + """
Args:
""" + __tail_doc__

HieraSmall.__doc__ = HieraTiny.__doc__
HieraBase.__doc__ = HieraTiny.__doc__
HieraBasePlus.__doc__ = HieraTiny.__doc__
HieraLarge.__doc__ = HieraTiny.__doc__
HieraHuge.__doc__ = HieraTiny.__doc__
206 changes: 206 additions & 0 deletions keras_cv_attention_models/hiera/hiera.py
@@ -0,0 +1,206 @@
import numpy as np
from keras_cv_attention_models import backend
from keras_cv_attention_models.backend import layers, models, functional, image_data_format
from keras_cv_attention_models.models import register_model
from keras_cv_attention_models.attention_layers import (
conv2d_no_bias,
drop_block,
mlp_block,
scaled_dot_product_attention,
PositionalEmbedding,
add_pre_post_process,
)
from keras_cv_attention_models.download_and_load import reload_model_weights

LAYER_NORM_EPSILON = 1e-6

PRETRAINED_DICT = {
"hiera_base": {"mae_in1k_ft1k": {224: "3909d65fbb85d229b1d640b70f6d5a52"}},
}


def mhsa_with_window_extracted_and_strides(
inputs, num_heads=4, key_dim=0, out_shape=None, window_size_prod=-1, strides_prod=1, qkv_bias=True, out_bias=True, attn_dropout=0, name=None
):
_, blocks, cc = inputs.shape
out_shape = cc if out_shape is None else out_shape
key_dim = key_dim if key_dim > 0 else out_shape // num_heads # Note: different from others using input_channels
qkv_out = num_heads * key_dim
qk_scale = 1.0 / (float(key_dim) ** 0.5)
window_size_prod = window_size_prod if window_size_prod > 0 else blocks
window_blocks = blocks // window_size_prod
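    # window_size_prod is the number of tokens attended to jointly; window_blocks is how many such windows the unrolled sequence splits into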
print(f"{blocks = }, {window_blocks = }, {window_size_prod = }, {num_heads = }")

qkv = layers.Dense(qkv_out * 3, use_bias=qkv_bias, name=name and name + "qkv")(inputs)
query, key, value = functional.split(qkv, 3, axis=-1)
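    # Attention is computed within each window of window_size_prod tokens; when strides_prod > 1 the query is additionally max-pooled (Q pooling), shrinking the output length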

if strides_prod > 1:
query = functional.reshape(query, [-1, window_size_prod // strides_prod, strides_prod, window_blocks * num_heads, key_dim])
query = functional.reduce_max(query, axis=-3)
query = functional.transpose(query, [0, 2, 1, 3])
else:
query = functional.transpose(functional.reshape(query, [-1, window_size_prod, window_blocks * num_heads, key_dim]), [0, 2, 1, 3])
key = functional.transpose(functional.reshape(key, [-1, window_size_prod, window_blocks * num_heads, key_dim]), [0, 2, 3, 1])
value = functional.transpose(functional.reshape(value, [-1, window_size_prod, window_blocks * num_heads, key_dim]), [0, 2, 1, 3])
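    # query: [batch, window_blocks * num_heads, window_size_prod // strides_prod, key_dim], key: [..., key_dim, window_size_prod], value: [..., window_size_prod, key_dim]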

output_shape = [-1, blocks // strides_prod, out_shape]
return scaled_dot_product_attention(query, key, value, output_shape, out_bias=out_bias, dropout=attn_dropout, name=name)


def attention_mlp_block(inputs, out_channels=-1, num_heads=4, window_size_prod=-1, strides_prod=1, mlp_ratio=4, drop_rate=0, activation="gelu", name=""):
# print(f">>>> {inputs.shape = }, {drop_rate = }")
input_channels = inputs.shape[-1]
out_channels = out_channels if out_channels > 0 else input_channels
pre = layers.LayerNormalization(axis=-1, epsilon=LAYER_NORM_EPSILON, name=name + "attn_ln")(inputs) # "channels_first" also using axis=-1
attn = mhsa_with_window_extracted_and_strides(
pre, num_heads=num_heads, out_shape=out_channels, window_size_prod=window_size_prod, strides_prod=strides_prod, name=name + "attn_"
)
attn = drop_block(attn, drop_rate)
if strides_prod > 1 or out_channels != input_channels:
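        # Shortcut branch: project channels if they change, then max-pool by strides_prod (a no-op when strides_prod is 1) so its length matches the attention output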
short = pre if out_channels == input_channels else layers.Dense(out_channels, name=name + "short_dense")(pre)
short = functional.reduce_max(functional.reshape(short, [-1, strides_prod, short.shape[1] // strides_prod, short.shape[-1]]), axis=1)
else:
short = inputs
attn_out = layers.Add(name=name + "attn_out")([short, attn])

""" MLP """
nn = layers.LayerNormalization(axis=-1, epsilon=LAYER_NORM_EPSILON, name=name + "mlp_ln")(attn_out) # "channels_first" also using axis=-1
nn = mlp_block(nn, hidden_dim=int(out_channels * mlp_ratio), activation=activation, name=name + "mlp_")
nn = drop_block(nn, drop_rate)
nn = layers.Add(name=name + "mlp_output")([attn_out, nn])
return nn


def unroll(inputs, strides=[2, 2, 2]):
"""
inputs: [batch, height, width, channels], strides: [2, 2, 2]
-> [batch, height // 8, 2_h3, 2_h2, 2_h1, width // 8, 2_w3, 2_w2, 2_w1, channels]
-> [batch, 2_h1, 2_w1, 2_h2, 2_w2, 2_h3, 2_w3, height // 8, width // 8, channels] # [0, 4, 8, 3, 7, 2, 6, 1, 5, 9]
-> [batch, height * width, channels]
"""
height, width, channels = inputs.shape[1:]
# nn = inputs
# for ii in strides:
# nn = functional.reshape(nn, [-1, nn.shape[-3] // ii, ii, nn.shape[-2] // ii, ii, nn.shape[-1]])
# nn = functional.transpose(nn, [0, 2, 4, 1, 3, 5])
# return functional.reshape(nn, [-1, height * width, channels])

height_strided = height // np.prod(strides)
width_strided = width // np.prod(strides)
inner_shape = [-1, height_strided, *strides, width_strided, *strides, channels]
nn = functional.reshape(inputs, inner_shape)

strides_len = len(strides) + 1
perm = [0] + np.ravel([[ii, ii + strides_len] for ii in range(strides_len, 0, -1)]).tolist() + [2 * strides_len + 1] # [0, 4, 8, 3, 7, 2, 6, 1, 5, 9]
nn = functional.transpose(nn, perm)

return functional.reshape(nn, [-1, height * width, channels])


def Hiera(
num_blocks=[1, 2, 7, 2],
embed_dim=96,
num_heads=[1, 2, 4, 8],
use_window_attentions=[True, True, [True, False], False], # [True, False] means first one is True, others are False
mlp_ratio=4,
# window_ratios=[8, 4, 1, 1],
input_shape=(224, 224, 3),
num_classes=1000,
activation="gelu",
drop_connect_rate=0,
dropout=0,
classifier_activation="softmax",
pretrained=None,
model_name="hiera",
kwargs=None,
):
    # If len(input_shape) == 4, regard it as forcing the original shape.
    # Otherwise assume the channel dimension is the one with the min value in input_shape, and put it first or last according to image_data_format
input_shape = backend.align_input_shape_by_image_data_format(input_shape)
inputs = layers.Input(input_shape)
strides = [1, 2, 2, 2]
window_size_prod = 8 * 8 # Total downsample rates after stem

""" forward_embeddings """
nn = conv2d_no_bias(inputs, embed_dim, 7, strides=4, padding="same", use_bias=True, name="stem_")
nn = nn if image_data_format() == "channels_last" else layers.Permute([2, 3, 1])(nn) # channels_first -> channels_last
nn = PositionalEmbedding(input_height=nn.shape[1], name="positional_embedding")(nn)
height, width = nn.shape[1:-1]
nn = unroll(nn, strides=strides[1:])
# window_ratios = (window_ratios[0] * window_ratios[1]) if isinstance(window_ratios, (list, tuple)) else window_ratios

""" stages """
total_blocks = sum(num_blocks)
global_block_id = 0
for stack_id, (num_block, stride, use_window_attention) in enumerate(zip(num_blocks, strides, use_window_attentions)):
stack_name = "stack{}_".format(stack_id + 1)
cur_out_channels = embed_dim * (2 ** stack_id)
cur_num_heads = num_heads[stack_id] if isinstance(num_heads, (list, tuple)) else num_heads
stack_use_window_attention = use_window_attention if isinstance(use_window_attention, (list, tuple)) else [use_window_attention]
stack_use_window_attention = stack_use_window_attention + stack_use_window_attention[-1:] * (num_block - len(stack_use_window_attention))
window_size_prod //= (stride ** 2)
for block_id in range(num_block):
block_name = stack_name + "block{}_".format(block_id + 1)
block_drop_rate = drop_connect_rate * global_block_id / total_blocks
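            # Q pooling: only the first block of a stack downsamples, by strides_prod = stride ** 2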
strides_prod = (stride ** 2) if block_id == 0 else 1
cur_window_size_prod = (window_size_prod * strides_prod) if stack_use_window_attention[block_id] else -1
nn = attention_mlp_block(
nn, cur_out_channels, cur_num_heads, cur_window_size_prod, strides_prod, mlp_ratio, block_drop_rate, activation=activation, name=block_name
)
global_block_id += 1
height, width = height // stride, width // stride # [TODO] reroll
nn = functional.reshape(nn, [-1, height, width, nn.shape[-1]])
nn = nn if image_data_format() == "channels_last" else layers.Permute([3, 1, 2])(nn) # channels_last -> channels_first

if num_classes > 0:
nn = layers.GlobalAveragePooling2D(name="avg_pool")(nn)
nn = layers.LayerNormalization(axis=-1, epsilon=LAYER_NORM_EPSILON, name="post_ln")(nn)
if dropout > 0:
nn = layers.Dropout(dropout, name="head_drop")(nn)
nn = layers.Dense(num_classes, dtype="float32", activation=classifier_activation, name="predictions")(nn)

model = models.Model(inputs, nn, name=model_name)
add_pre_post_process(model, rescale_mode="torch")
reload_model_weights(model, PRETRAINED_DICT, "hiera", pretrained, PositionalEmbedding)
return model


@register_model
def HieraTiny(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained=None, **kwargs):
return Hiera(**locals(), model_name="hiera_tiny", **kwargs)


@register_model
def HieraSmall(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained=None, **kwargs):
num_blocks = [1, 2, 11, 2]
return Hiera(**locals(), model_name="hiera_small", **kwargs)


@register_model
def HieraBase(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="mae_in1k_ft1k", **kwargs):
num_blocks = [2, 3, 16, 3]
return Hiera(**locals(), model_name="hiera_base", **kwargs)


@register_model
def HieraBasePlus(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained=None, **kwargs):
num_blocks = [2, 3, 16, 3]
embed_dim = 112
num_heads = [2, 4, 8, 16]
return Hiera(**locals(), model_name="hiera_base_plus", **kwargs)


@register_model
def HieraLarge(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained=None, **kwargs):
num_blocks = [2, 6, 36, 4]
embed_dim = 144
num_heads = [2, 4, 8, 16]
return Hiera(**locals(), model_name="hiera_large", **kwargs)


@register_model
def HieraHuge(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained=None, **kwargs):
num_blocks = [2, 6, 36, 4]
embed_dim = 256
num_heads = [4, 8, 16, 32]
return Hiera(**locals(), model_name="hiera_huge", **kwargs)
