diff --git a/backbones/vit.py b/backbones/vit.py
index 7dd27d5..a7b7fd2 100644
--- a/backbones/vit.py
+++ b/backbones/vit.py
@@ -3,7 +3,6 @@
 # Copyright (c) 2021 edwardyehuang (https://github.com/edwardyehuang)
 # ================================================================
 
-# WIP DO NOT USE
 
 import tensorflow as tf
 
@@ -12,98 +11,96 @@
 from iseg.layers.common_layers import PatchEmbed
 
 
-class VisionTransformer(tf.keras.Model):
-    def __init__(
-        self,
-        patch_size,
-        num_layer,
-        num_head,
-        filters=768,
-        mlp_filters=4096,
-        dropout_rate=0.1,
-        return_endpoints=False,
-        name=None
-    ):
+def resize_pos_embed(
+    pos_embed,  # [1, num_extra_tokens + h * w, C]
+    target_size,
+    num_extra_tokens=1
+):
+
+    assert len(pos_embed.shape) == 3
 
-        super().__init__(name=name)
+    # Class/extra tokens carry no spatial layout, so they are split off
+    # here and reattached unchanged after the resize.
+    extra_tokens = pos_embed[:, :num_extra_tokens]
+    pos_embed = pos_embed[:, num_extra_tokens:]
 
-        self.patch_size = patch_size
-        self.num_layer = num_layer
-        self.num_head = num_head
-        self.filters = filters
-        self.mlp_filters = mlp_filters
-        self.dropout_rate = dropout_rate
-
-        self.return_endpoints = return_endpoints
-
-    def build(self, input_shape):
-
-        self.patch_encoder = PatchEmbed(
-            patch_size=(self.patch_size, self.patch_size),
-            embed_filters=self.filters,
-            name="embedding",
-        )
-
-        self.positional_encoder = PositionalEncoder(name="posembed_input")
-
-        self.blocks = []
-
-        for i in range(self.num_layer):
-            self.blocks += [
-                TransformerBlock(
-                    self.mlp_filters, num_heads=self.num_head, dropout_rate=self.dropout_rate, name=f"encoderblock_{i}"
-                )
-            ]
+    pos_embed_shape = tf.shape(pos_embed)
 
-    def call(self, inputs, training=None):
+    batch_size = pos_embed_shape[0]
+    pos_embed_length = tf.cast(pos_embed_shape[1], tf.float32)
 
-        x = inputs
+    # The pre-trained spatial tokens are assumed to form a square grid,
+    # so one axis length is the square root of the token count.
+    pos_embed_axis_length = tf.cast(
+        tf.sqrt(pos_embed_length), dtype=tf.int32
+    )
 
-        x = self.patch_encoder(x)
-
-        batch_size, height, width, channels = get_tensor_shape(x)
-
-        x = flatten_hw(x)
+    pos_embed = tf.reshape(pos_embed, [
+        batch_size,
+        pos_embed_axis_length,
+        pos_embed_axis_length,
+        pos_embed.shape[-1],
+    ])
 
-        x = self.positional_encoder(x)
+    pos_embed = tf.image.resize(
+        pos_embed,
+        size=target_size,
+        method=tf.image.ResizeMethod.BICUBIC,
+        name="pos_embed_resize"
+    )
 
-        for i in range(self.num_layer):
-            x = self.blocks[i](x, training=training)
+    pos_embed = tf.reshape(pos_embed, [batch_size, -1, pos_embed.shape[-1]])
 
-        x = tf.reshape(x, [batch_size, height, width, channels])
+    return tf.concat([extra_tokens, pos_embed], axis=1)
 
-        if self.return_endpoints:
-            x = [x]
-
-        return x
-
+
+class MLPBlock(tf.keras.Model):
+    def __init__(
+        self,
+        filters,
+        dropout_rate=0.0,
+        activation=tf.nn.gelu,
+        name=None
+    ):
+        super().__init__(name=name)
+
+        self.filters = filters
+        self.dropout_rate = dropout_rate
+        self.activation = activation
 
-class PositionalEncoder(tf.keras.Model):
-    def __init__(self, trainable=True, name=None):
-        super().__init__(name=name, trainable=trainable)
 
     def build(self, input_shape):
 
-        # input_shape = batch_size, num_patches, channels
-
-        self.num_patches = input_shape[1]
-
-        filters = input_shape[-1]
+        self.dense0 = tf.keras.layers.Dense(
+            self.filters,
+            activation=self.activation,
+            name=f"{self.name}/dense0"
+        )
+
+        self.dense0_dropout = tf.keras.layers.Dropout(self.dropout_rate)
 
-        self.position_embedding = tf.keras.layers.Embedding(
-            input_dim=self.num_patches, output_dim=filters, name="pos_embedding"
+        self.dense1 = tf.keras.layers.Dense(
+            input_shape[-1],
name=f"{self.name}/dense1" + ) + + self.dense1_dropout = tf.keras.layers.Dropout( + self.dropout_rate, ) - self.built = True + super().build(input_shape) - def call(self, inputs): - positions = tf.range(start=0, limit=self.num_patches, delta=1) - encoded = inputs + self.position_embedding(positions) + def call(self, inputs, training=None): - return encoded + x = self.dense0(inputs) + x = self.dense0_dropout(x, training=training) + x = self.dense1(x) + x = self.dense1_dropout(x, training=training) + + return x + class TransformerBlock(tf.keras.Model): def __init__(self, mlp_filters=4096, num_heads=16, dropout_rate=0.1, name=None): @@ -117,13 +114,31 @@ def build(self, input_shape): channels = input_shape[-1] - self.attention_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="LayerNorm_0") + self.attention_norm = tf.keras.layers.LayerNormalization( + epsilon=1e-6, + name=f"{self.name}/ln1" + ) + self.attention = tf.keras.layers.MultiHeadAttention( - num_heads=self.num_head, key_dim=channels, dropout=self.dropout_rate, name="MultiHeadDotProductAttention_1" + num_heads=self.num_head, + key_dim=channels // self.num_head, + dropout=self.dropout_rate, + name=f"{self.name}/attn" ) - self.mlp_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="LayerNorm_2") - self.mlp = MLPBlock(self.mlp_filters, self.dropout_rate, name="MlpBlock_3") + self.mlp_norm = tf.keras.layers.LayerNormalization( + epsilon=1e-6, + name=f"{self.name}/ln2" + ) + + self.mlp = MLPBlock( + self.mlp_filters, + self.dropout_rate, + name=f"{self.name}/ffn" + ) + + super().build(input_shape) + def call(self, inputs, training=None): @@ -140,30 +155,126 @@ def call(self, inputs, training=None): return x -class MLPBlock(tf.keras.Model): - def __init__(self, filters, dropout_rate=0.1, name=None): +class VisionTransformer(tf.keras.Model): + def __init__( + self, + patch_size, + num_layer, + num_head, + filters=768, + mlp_filters=4096, + dropout_rate=0.1, + use_class_token=True, + pretrain_size=224, + return_endpoints=False, + name=None + ): + super().__init__(name=name) + self.patch_size = patch_size + self.num_layer = num_layer + self.num_head = num_head self.filters = filters + self.mlp_filters = mlp_filters self.dropout_rate = dropout_rate + self.use_class_token = use_class_token + + self.pretrain_size = pretrain_size + + self.return_endpoints = return_endpoints + + + def build(self, input_shape): - self.dense0 = tf.keras.layers.Dense(self.filters, activation=tf.nn.gelu, name="Dense_0") - self.dense0_dropout = tf.keras.layers.Dropout(self.dropout_rate) + self.patch_encoder = PatchEmbed( + patch_size=(self.patch_size, self.patch_size), + embed_filters=self.filters, + name=f"{self.name}/patch_embed", + ) - self.dense1 = tf.keras.layers.Dense(input_shape[-1], name="Dense_1") - self.dense1_dropout = tf.keras.layers.Dropout(self.dropout_rate) + num_patches_axis = self.pretrain_size // self.patch_size + + self.num_patches = num_patches_axis ** 2 + self.extra_patches = 0 + + if self.use_class_token: + self.class_token = self.add_weight( + f"{self.name}/class_token", + shape=[1, 1, self.filters], + dtype=tf.float32, + initializer=tf.keras.initializers.Zeros(), + trainable=True, + ) + + self.extra_patches = 1 + + self.position_embedding = self.add_weight( + f"{self.name}/pos_embed", + shape=[1, self.num_patches + self.extra_patches, self.filters], + dtype=tf.float32, + initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02), + trainable=True, + ) + + self.blocks = [] + + for i in range(self.num_layer): + self.blocks 
+            self.blocks += [
+                TransformerBlock(
+                    self.mlp_filters,
+                    num_heads=self.num_head,
+                    dropout_rate=self.dropout_rate,
+                    name=f"{self.name}/layers/{i}"
+                )
+            ]
+
+        super().build(input_shape)
+
 
     def call(self, inputs, training=None):
 
-        x = self.dense0(inputs)
-        x = self.dense0_dropout(x, training=training)
+        x = inputs
 
-        x = self.dense1(x)
-        x = self.dense1_dropout(x, training=training)
+        x = self.patch_encoder(x)
+
+        batch_size, height, width, channels = get_tensor_shape(x)
+
+        x = flatten_hw(x)
+
+        if self.use_class_token:
+            class_token = self.class_token
+            class_token = tf.broadcast_to(
+                class_token,
+                shape=[tf.shape(x)[0], 1, class_token.shape[-1]],
+                name="class_token_batch_broadcast"
+            )
+
+            x = tf.concat([class_token, x], axis=1)
+
+        # Resize the pre-trained position embedding to the current token
+        # grid, so inputs need not match pretrain_size.
+        position_embedding = resize_pos_embed(
+            pos_embed=self.position_embedding,
+            target_size=(height, width),
+            num_extra_tokens=self.extra_patches,
+        )
+
+        x = tf.add(x, position_embedding, name="position_embedding_add")
+
+        for i in range(self.num_layer):
+            x = self.blocks[i](x, training=training)
+
+        if self.use_class_token:
+            x = x[:, self.extra_patches:]  # remove class token
+
+        x = tf.reshape(x, [batch_size, height, width, channels])
+
+        if self.return_endpoints:
+            x = [x]
 
         return x
 
+
 
 def ViT16L(return_endpoints=False):
 
@@ -187,6 +298,7 @@ def ViT16B(return_endpoints=False):
         num_head=12,
         filters=768,
         mlp_filters=3072,
+        pretrain_size=384,
         return_endpoints=return_endpoints,
         name="ViT-B_16"
     )
 
@@ -200,6 +312,7 @@ def ViT16S(return_endpoints=False):
         num_head=6,
         filters=384,
         mlp_filters=1536,
+        pretrain_size=384,
         return_endpoints=return_endpoints,
         name="ViT-S_16"
     )
diff --git a/layers/common_layers.py b/layers/common_layers.py
index 7b57370..b31a6e8 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -64,17 +64,24 @@ def __init__(self, patch_size=(4, 4), embed_filters=96, norm_layer=None, name=None):
         super().__init__(name=name)
 
         self.patch_size = patch_size
-        self.embed_dim = embed_filters
+        self.embed_filters = embed_filters
+
+        self.norm_layer = norm_layer
+
+    def build(self, input_shape):
 
         self.proj = tf.keras.layers.Conv2D(
-            embed_filters, kernel_size=patch_size, strides=patch_size, name=f"{name}/proj"
+            self.embed_filters,
+            kernel_size=self.patch_size,
+            strides=self.patch_size,
+            name=f"{self.name}/projection"
         )
-        if norm_layer is not None:
-            self.norm = norm_layer(epsilon=1e-5, name=f"{name}/norm")
+
+        if self.norm_layer is not None:
+            self.norm = self.norm_layer(epsilon=1e-5, name=f"{self.name}/norm")
         else:
             self.norm = None
 
-    def build(self, input_shape):
         super().build(input_shape)
 
     def call(self, x):
@@ -93,7 +100,13 @@ def call(self, x):
 
         x = self.proj(x)
 
         x = tf.reshape(
-            x, shape=[-1, (padded_height // self.patch_size[0]), (padded_width // self.patch_size[0]), self.embed_dim]
+            x,
+            shape=[
+                -1,
+                (padded_height // self.patch_size[0]),
+                (padded_width // self.patch_size[1]),  # width uses its own patch size
+                self.embed_filters
+            ]
         )
 
         if self.norm is not None:
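
Reviewer note: the bicubic trick in resize_pos_embed can be sanity-checked in isolation. A minimal sketch, assuming TensorFlow 2.x and the ViT-/16 geometry used in this PR (a 24x24 token grid at pretrain_size=384, one class token); the tensor here is a random stand-in, not real weights.

    import tensorflow as tf

    # Pre-trained layout: 24 x 24 patch tokens (384 // 16) plus one class token.
    pos_embed = tf.random.normal([1, 1 + 24 * 24, 768])

    cls_part = pos_embed[:, :1]  # the class token is not spatial, keep as-is
    grid_part = tf.reshape(pos_embed[:, 1:], [1, 24, 24, 768])

    # Resize the spatial grid to the token grid of a 512 x 512 input (32 x 32).
    grid_part = tf.image.resize(grid_part, size=(32, 32), method="bicubic")
    grid_part = tf.reshape(grid_part, [1, -1, 768])

    resized = tf.concat([cls_part, grid_part], axis=1)
    print(resized.shape)  # (1, 1025, 768) == (1, 1 + 32 * 32, 768)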
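Reviewer note: end to end, the on-the-fly resize means the backbone accepts input sizes other than pretrain_size. A usage sketch; the iseg.backbones.vit module path is an assumption based on the imports above, not confirmed by this diff.

    import tensorflow as tf
    from iseg.backbones.vit import ViT16B  # assumed import path for this repo

    backbone = ViT16B()

    # 512 differs from the 384 pretrain size; the position embedding is
    # resized inside call(), so this still works.
    images = tf.random.uniform([2, 512, 512, 3])
    features = backbone(images, training=False)

    print(features.shape)  # expected: (2, 32, 32, 768), since 512 // 16 = 32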
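Reviewer note: the switch from key_dim=channels to key_dim=channels // num_heads matches the standard ViT split of the embedding across heads (768 / 12 = 64 for ViT-B) and shrinks the attention projections by a factor of num_heads. Keras still projects the output back to the query's channel count, as this small check shows:

    import tensorflow as tf

    channels, num_heads = 768, 12

    mha = tf.keras.layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=channels // num_heads,  # 64 per head, as in the original ViT
    )

    x = tf.random.normal([1, 197, channels])  # 196 patch tokens + 1 class token
    y = mha(x, x)

    print(y.shape)  # (1, 197, 768): output is projected back to `channels`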