@@ -0,0 +1,53 @@
# ___Keras Hiera___
***

## Summary
- Keras implementation of [Github facebookresearch/hiera](https://github.com/facebookresearch/hiera). Paper [PDF 2306.00989 Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles](https://arxiv.org/pdf/2306.00989.pdf).
- Model weights ported from the official publication.
***

## Models
| Model         | Params  | FLOPs   | Input | Top1 Acc | Download |
| ------------- | ------- | ------- | ----- | -------- | -------- |
| HieraTiny     | 27.91M  | 4.93G   | 224   | 82.8     |          |
| HieraSmall    | 35.01M  | 6.44G   | 224   | 83.8     |          |
| HieraBase     | 51.52M  | 9.43G   | 224   | 84.5     |          |
| HieraBasePlus | 69.90M  | 12.71G  | 224   | 85.2     |          |
| HieraLarge    | 213.74M | 40.43G  | 224   | 86.1     |          |
| HieraHuge     | 672.78M | 125.03G | 224   | 86.9     |          |

## Usage
```py
from keras_cv_attention_models import hiera, test_images

# Will download and load pretrained imagenet weights.
mm = hiera.HieraBase()

# Run prediction
preds = mm(mm.preprocess_input(test_images.cat()))
print(mm.decode_predictions(preds))
# [('n02124075', 'Egyptian_cat', 0.8966972), ('n02123045', 'tabby', 0.0072582546), ...]
```
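Per the docstring, setting `num_classes=0` excludes the top layers, so the model can serve as a feature backbone. A minimal sketch; the expected output shape follows from this implementation (stride-4 stem, then three stride-2 stages with channels doubling from `embed_dim=96`):
```py
from keras_cv_attention_models import hiera, test_images

# num_classes=0 builds the model without the avg-pool + Dense classification head
mm = hiera.HieraBase(num_classes=0)
features = mm(mm.preprocess_input(test_images.cat()))
print(features.shape)  # Expect (1, 7, 7, 768) for a 224 channels_last input
```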
## Verification with PyTorch version
```py
""" PyTorch torch_hiera """
import sys
sys.path.append('../hiera/')
sys.path.append('../pytorch-image-models/')  # Needs timm
import torch
from hiera import hiera as torch_hiera

torch_model = torch_hiera.hiera_base_224()
ss = torch.load('hiera_base_224.pth', map_location=torch.device('cpu'))
torch_model.load_state_dict(ss['model_state'])
_ = torch_model.eval()

""" Keras HieraBase """
from keras_cv_attention_models import hiera
mm = hiera.HieraBase(classifier_activation="softmax")

""" Verification """
import numpy as np
inputs = np.random.uniform(size=(1, *mm.input_shape[1:3], 3)).astype("float32")
torch_out = torch_model(torch.from_numpy(inputs).permute(0, 3, 1, 2)).detach().numpy()
keras_out = mm(inputs).numpy()
print(f"{np.allclose(torch_out, keras_out, atol=5e-2) = }")
# np.allclose(torch_out, keras_out, atol=5e-2) = True
```
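The parent package documents switching its compute backend through the `KECAM_BACKEND` environment variable; assuming Hiera follows that repo-wide pattern (the verification above suggests output parity), a sketch of running the same prediction on the PyTorch backend:
```py
import os
os.environ["KECAM_BACKEND"] = "torch"  # Must be set before the first package import

from keras_cv_attention_models import hiera, test_images

mm = hiera.HieraBase()  # Same API, built on PyTorch layers under this backend
preds = mm(mm.preprocess_input(test_images.cat()))
print(mm.decode_predictions(preds))
```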
@@ -0,0 +1,55 @@
from keras_cv_attention_models.hiera.hiera import Hiera, HieraTiny, HieraSmall, HieraBase, HieraBasePlus, HieraLarge, HieraHuge

__head_doc__ = """
Keras implementation of [Github facebookresearch/hiera](https://github.com/facebookresearch/hiera).
Paper [PDF 2306.00989 Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles](https://arxiv.org/pdf/2306.00989.pdf).
"""
__tail_doc__ = """  input_shape: it should have exactly 3 input channels, like `(224, 224, 3)`.
  num_classes: number of classes to classify images into. Set `0` to exclude top layers.
  activation: activation used in the whole model, default `gelu`.
  drop_connect_rate: used for [Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382).
      Can be a value like `0.2`, indicating the drop probability changes linearly from `0 --> 0.2` for `top --> bottom` layers.
      A higher value means a higher probability of dropping the deep branch. Or `0` to disable (default).
  dropout: dropout rate if top layers are included.
  classifier_activation: A `str` or callable. The activation function to use on the "top" layer if `num_classes > 0`.
      Set `classifier_activation=None` to return the logits of the "top" layer.
  pretrained: one of `None` or `"mae_in1k_ft1k"`. Currently only `HieraBase` weights are available.
      Will try to download and load pre-trained model weights if not None.

Returns:
    A `keras.Model` instance.
"""

Hiera.__doc__ = __head_doc__ + """
Args:
  num_blocks: number of blocks in each stack.
  embed_dim: basic hidden dim, expanded by `2` for each stack.
  num_heads: int or list value for the number of heads in each stack.
  use_window_attentions: boolean or list value, each value in the list can also be a list of booleans.
      Indicates whether to use window attention in each stack.
      An element value like `[True, False]` means the first block is True, others are False.
  mlp_ratio: expansion ratio for the mlp block hidden channels.
  model_name: string, model name.
""" + __tail_doc__ + """
Model architectures:
  | Model         | Params  | FLOPs   | Input | Top1 Acc |
  | ------------- | ------- | ------- | ----- | -------- |
  | HieraTiny     | 27.91M  | 4.93G   | 224   | 82.8     |
  | HieraSmall    | 35.01M  | 6.44G   | 224   | 83.8     |
  | HieraBase     | 51.52M  | 9.43G   | 224   | 84.5     |
  | HieraBasePlus | 69.90M  | 12.71G  | 224   | 85.2     |
  | HieraLarge    | 213.74M | 40.43G  | 224   | 86.1     |
  | HieraHuge     | 672.78M | 125.03G | 224   | 86.9     |
"""

HieraTiny.__doc__ = __head_doc__ + """
Args:
""" + __tail_doc__

HieraSmall.__doc__ = HieraTiny.__doc__
HieraBase.__doc__ = HieraTiny.__doc__
HieraBasePlus.__doc__ = HieraTiny.__doc__
HieraLarge.__doc__ = HieraTiny.__doc__
HieraHuge.__doc__ = HieraTiny.__doc__
@@ -0,0 +1,206 @@
import numpy as np
from keras_cv_attention_models import backend
from keras_cv_attention_models.backend import layers, models, functional, image_data_format
from keras_cv_attention_models.models import register_model
from keras_cv_attention_models.attention_layers import (
    conv2d_no_bias,
    drop_block,
    mlp_block,
    scaled_dot_product_attention,
    PositionalEmbedding,
    add_pre_post_process,
)
from keras_cv_attention_models.download_and_load import reload_model_weights

LAYER_NORM_EPSILON = 1e-6

PRETRAINED_DICT = {
    "hiera_base": {"mae_in1k_ft1k": {224: "3909d65fbb85d229b1d640b70f6d5a52"}},
}


def mhsa_with_window_extracted_and_strides(
    inputs, num_heads=4, key_dim=0, out_shape=None, window_size_prod=-1, strides_prod=1, qkv_bias=True, out_bias=True, attn_dropout=0, name=None
):
    _, blocks, cc = inputs.shape
    out_shape = cc if out_shape is None else out_shape
    key_dim = key_dim if key_dim > 0 else out_shape // num_heads  # Note: different from others using input_channels
    qkv_out = num_heads * key_dim
    qk_scale = 1.0 / (float(key_dim) ** 0.5)
    window_size_prod = window_size_prod if window_size_prod > 0 else blocks
    window_blocks = blocks // window_size_prod
    # print(f"{blocks = }, {window_blocks = }, {window_size_prod = }, {num_heads = }")  # Debug only

    qkv = layers.Dense(qkv_out * 3, use_bias=qkv_bias, name=name and name + "qkv")(inputs)
    query, key, value = functional.split(qkv, 3, axis=-1)

    if strides_prod > 1:
        # Query pooling: reduce_max over the strides_prod axis shrinks the query length, downsampling the output
        query = functional.reshape(query, [-1, window_size_prod // strides_prod, strides_prod, window_blocks * num_heads, key_dim])
        query = functional.reduce_max(query, axis=-3)
        query = functional.transpose(query, [0, 2, 1, 3])
    else:
        query = functional.transpose(functional.reshape(query, [-1, window_size_prod, window_blocks * num_heads, key_dim]), [0, 2, 1, 3])
    key = functional.transpose(functional.reshape(key, [-1, window_size_prod, window_blocks * num_heads, key_dim]), [0, 2, 3, 1])
    value = functional.transpose(functional.reshape(value, [-1, window_size_prod, window_blocks * num_heads, key_dim]), [0, 2, 1, 3])

    output_shape = [-1, blocks // strides_prod, out_shape]
    return scaled_dot_product_attention(query, key, value, output_shape, out_bias=out_bias, dropout=attn_dropout, name=name)
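

# Shape walkthrough for reference (HieraBase at 224 input, stack 1): the stem yields 56 * 56 = 3136
# tokens and window_size_prod = 64, so window_blocks = 3136 // 64 = 49. The unroll layout keeps the
# intra-window offset as the outer factor of the token axis, so reshaping [3136] -> [64, 49] extracts
# the 49 windows without any gather; folding them into the head axis makes attention run independently
# within each 64-token (8x8) mask unit.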
|
||
def attention_mlp_block(inputs, out_channels=-1, num_heads=4, window_size_prod=-1, strides_prod=1, mlp_ratio=4, drop_rate=0, activation="gelu", name=""): | ||
# print(f">>>> {inputs.shape = }, {drop_rate = }") | ||
input_channels = inputs.shape[-1] | ||
out_channels = out_channels if out_channels > 0 else input_channels | ||
pre = layers.LayerNormalization(axis=-1, epsilon=LAYER_NORM_EPSILON, name=name + "attn_ln")(inputs) # "channels_first" also using axis=-1 | ||
attn = mhsa_with_window_extracted_and_strides( | ||
pre, num_heads=num_heads, out_shape=out_channels, window_size_prod=window_size_prod, strides_prod=strides_prod, name=name + "attn_" | ||
) | ||
attn = drop_block(attn, drop_rate) | ||
if strides_prod > 1 or out_channels != input_channels: | ||
short = pre if out_channels == input_channels else layers.Dense(out_channels, name=name + "short_dense")(pre) | ||
short = functional.reduce_max(functional.reshape(short, [-1, strides_prod, short.shape[1] // strides_prod, short.shape[-1]]), axis=1) | ||
else: | ||
short = inputs | ||
attn_out = layers.Add(name=name + "attn_out")([short, attn]) | ||
|
||
""" MLP """ | ||
nn = layers.LayerNormalization(axis=-1, epsilon=LAYER_NORM_EPSILON, name=name + "mlp_ln")(attn_out) # "channels_first" also using axis=-1 | ||
nn = mlp_block(nn, hidden_dim=int(out_channels * mlp_ratio), activation=activation, name=name + "mlp_") | ||
nn = drop_block(nn, drop_rate) | ||
nn = layers.Add(name=name + "mlp_output")([attn_out, nn]) | ||
return nn | ||
|
||
|
||


def unroll(inputs, strides=[2, 2, 2]):
    """
    inputs: [batch, height, width, channels], strides: [2, 2, 2]
    -> [batch, height // 8, 2_h3, 2_h2, 2_h1, width // 8, 2_w3, 2_w2, 2_w1, channels]
    -> [batch, 2_h1, 2_w1, 2_h2, 2_w2, 2_h3, 2_w3, height // 8, width // 8, channels]  # [0, 4, 8, 3, 7, 2, 6, 1, 5, 9]
    -> [batch, height * width, channels]
    """
    height, width, channels = inputs.shape[1:]
    # nn = inputs
    # for ii in strides:
    #     nn = functional.reshape(nn, [-1, nn.shape[-3] // ii, ii, nn.shape[-2] // ii, ii, nn.shape[-1]])
    #     nn = functional.transpose(nn, [0, 2, 4, 1, 3, 5])
    # return functional.reshape(nn, [-1, height * width, channels])

    height_strided = height // np.prod(strides)
    width_strided = width // np.prod(strides)
    inner_shape = [-1, height_strided, *strides, width_strided, *strides, channels]
    nn = functional.reshape(inputs, inner_shape)

    strides_len = len(strides) + 1
    perm = [0] + np.ravel([[ii, ii + strides_len] for ii in range(strides_len, 0, -1)]).tolist() + [2 * strides_len + 1]  # [0, 4, 8, 3, 7, 2, 6, 1, 5, 9]
    nn = functional.transpose(nn, perm)

    return functional.reshape(nn, [-1, height * width, channels])
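

# Minimal self-check for reference, with a 4x4 single-channel input and strides=[2]: the flat token
# order [0..15] becomes [0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15], i.e. the four
# positions sharing the same offset within their 2x2 unit become one contiguous group, so later
# window extraction and pooling reduce to plain reshapes.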


def Hiera(
    num_blocks=[1, 2, 7, 2],
    embed_dim=96,
    num_heads=[1, 2, 4, 8],
    use_window_attentions=[True, True, [True, False], False],  # [True, False] means the first block is True, others are False
    mlp_ratio=4,
    # window_ratios=[8, 4, 1, 1],
    input_shape=(224, 224, 3),
    num_classes=1000,
    activation="gelu",
    drop_connect_rate=0,
    dropout=0,
    classifier_activation="softmax",
    pretrained=None,
    model_name="hiera",
    kwargs=None,
):
    # Regard input_shape as forcing the original shape if len(input_shape) == 4,
    # else assume the channel dimension is the one with the min value in input_shape, and place it first or last regarding image_data_format
    input_shape = backend.align_input_shape_by_image_data_format(input_shape)
    inputs = layers.Input(input_shape)
    strides = [1, 2, 2, 2]
    window_size_prod = 8 * 8  # Total downsample rates after stem

    """ forward_embeddings """
    nn = conv2d_no_bias(inputs, embed_dim, 7, strides=4, padding="same", use_bias=True, name="stem_")
    nn = nn if image_data_format() == "channels_last" else layers.Permute([2, 3, 1])(nn)  # channels_first -> channels_last
    nn = PositionalEmbedding(input_height=nn.shape[1], name="positional_embedding")(nn)
    height, width = nn.shape[1:-1]
    nn = unroll(nn, strides=strides[1:])
    # window_ratios = (window_ratios[0] * window_ratios[1]) if isinstance(window_ratios, (list, tuple)) else window_ratios

    """ stages """
    total_blocks = sum(num_blocks)
    global_block_id = 0
    for stack_id, (num_block, stride, use_window_attention) in enumerate(zip(num_blocks, strides, use_window_attentions)):
        stack_name = "stack{}_".format(stack_id + 1)
        cur_out_channels = embed_dim * (2 ** stack_id)
        cur_num_heads = num_heads[stack_id] if isinstance(num_heads, (list, tuple)) else num_heads
        stack_use_window_attention = use_window_attention if isinstance(use_window_attention, (list, tuple)) else [use_window_attention]
        stack_use_window_attention = stack_use_window_attention + stack_use_window_attention[-1:] * (num_block - len(stack_use_window_attention))
        window_size_prod //= (stride ** 2)
        for block_id in range(num_block):
            block_name = stack_name + "block{}_".format(block_id + 1)
            block_drop_rate = drop_connect_rate * global_block_id / total_blocks
            strides_prod = (stride ** 2) if block_id == 0 else 1
            cur_window_size_prod = (window_size_prod * strides_prod) if stack_use_window_attention[block_id] else -1
            nn = attention_mlp_block(
                nn, cur_out_channels, cur_num_heads, cur_window_size_prod, strides_prod, mlp_ratio, block_drop_rate, activation=activation, name=block_name
            )
            global_block_id += 1
        height, width = height // stride, width // stride  # [TODO] reroll
    nn = functional.reshape(nn, [-1, height, width, nn.shape[-1]])
    nn = nn if image_data_format() == "channels_last" else layers.Permute([3, 1, 2])(nn)  # channels_last -> channels_first

    if num_classes > 0:
        nn = layers.GlobalAveragePooling2D(name="avg_pool")(nn)
        nn = layers.LayerNormalization(axis=-1, epsilon=LAYER_NORM_EPSILON, name="post_ln")(nn)
        if dropout > 0:
            nn = layers.Dropout(dropout, name="head_drop")(nn)
        nn = layers.Dense(num_classes, dtype="float32", activation=classifier_activation, name="predictions")(nn)

    model = models.Model(inputs, nn, name=model_name)
    add_pre_post_process(model, rescale_mode="torch")
    reload_model_weights(model, PRETRAINED_DICT, "hiera", pretrained, PositionalEmbedding)
    return model
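

# For reference: window_size_prod starts at 8 * 8 = 64 (one mask unit, matching the 8x total stride
# of stacks 2-4) and is divided by stride**2 entering each stack, giving per-window token counts of
# 64, 16, 4, 1 across stacks 1-4; each stack's first block scales it back up by strides_prod so the
# pooled queries still cover the same window.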


@register_model
def HieraTiny(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained=None, **kwargs):
    return Hiera(**locals(), model_name="hiera_tiny", **kwargs)


@register_model
def HieraSmall(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained=None, **kwargs):
    num_blocks = [1, 2, 11, 2]
    return Hiera(**locals(), model_name="hiera_small", **kwargs)


@register_model
def HieraBase(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained="mae_in1k_ft1k", **kwargs):
    num_blocks = [2, 3, 16, 3]
    return Hiera(**locals(), model_name="hiera_base", **kwargs)


@register_model
def HieraBasePlus(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained=None, **kwargs):
    num_blocks = [2, 3, 16, 3]
    embed_dim = 112
    num_heads = [2, 4, 8, 16]
    return Hiera(**locals(), model_name="hiera_base_plus", **kwargs)


@register_model
def HieraLarge(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained=None, **kwargs):
    num_blocks = [2, 6, 36, 4]
    embed_dim = 144
    num_heads = [2, 4, 8, 16]
    return Hiera(**locals(), model_name="hiera_large", **kwargs)


@register_model
def HieraHuge(input_shape=(224, 224, 3), num_classes=1000, activation="gelu", classifier_activation="softmax", pretrained=None, **kwargs):
    num_blocks = [2, 6, 36, 4]
    embed_dim = 256
    num_heads = [4, 8, 16, 32]
    return Hiera(**locals(), model_name="hiera_huge", **kwargs)