diff --git a/./cct.py b/./cct_keras_core.py
index e286998..99a0008 100644
--- a/./cct.py
+++ b/./cct_keras_core.py
@@ -1,8 +1,9 @@
 """
 Title: Compact Convolutional Transformers
 Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Converted to Keras Core by: [Muhammad Anas Raza](https://anasrz.com)
 Date created: 2021/06/30
-Last modified: 2021/06/30
+Last modified: 2023/07/17
 Description: Compact Convolutional Transformers for efficient image classification.
 Accelerator: GPU
 """
@@ -30,23 +31,21 @@ from François Chollet's book *Deep Learning with Python*.
 
 This example uses code snippets from another example,
 [Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
 
-This example requires TensorFlow 2.5 or higher, as well as TensorFlow Addons, which can
-be installed using the following command:
-"""
-
-"""shell
-pip install -U -q tensorflow-addons
 """
+
 
 """
 ## Imports
 """
 
-from tensorflow.keras import layers
-from tensorflow import keras
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+from keras_core import layers
+import keras_core as keras
 
 import matplotlib.pyplot as plt
-import tensorflow_addons as tfa
 import tensorflow as tf
 import numpy as np
@@ -167,6 +166,25 @@ class CCTTokenizer(layers.Layer):
         else:
             return None
 
+"""
+## Sequence Pooling
+Another recipe introduced in CCT is attention pooling or sequence pooling. In ViT, only
+the feature map corresponding to the class token is pooled and is then used for the
+subsequent classification task (or any other downstream task).
+"""
+
+class SequencePooling(layers.Layer):
+    def __init__(self):
+        super().__init__()
+        self.attention = layers.Dense(1)
+
+    def call(self, x):
+        attention_weights = tf.nn.softmax(self.attention(x), axis=1)
+        weighted_representation = tf.matmul(
+            attention_weights, x, transpose_a=True
+        )
+        return tf.squeeze(weighted_representation, -2)
+
 """
 ## Stochastic depth for regularization
 
@@ -230,14 +248,13 @@ data_augmentation = keras.Sequential(
 
 """
 ## The final CCT model
 
-Another recipe introduced in CCT is attention pooling or sequence pooling. In ViT, only
-the feature map corresponding to the class token is pooled and is then used for the
-subsequent classification task (or any other downstream task). In CCT, outputs from the
-Transformers encoder are weighted and then passed on to the final task-specific layer (in
+In CCT, outputs from the Transformers encoder are weighted and then passed on to the final task-specific layer (in
 this example, we do classification).
 """
 
+
+
 def create_cct_model(
     image_size=image_size,
     input_shape=input_shape,
@@ -290,11 +307,7 @@ def create_cct_model(
 
     # Apply sequence pooling.
     representation = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
-    attention_weights = tf.nn.softmax(layers.Dense(1)(representation), axis=1)
-    weighted_representation = tf.matmul(
-        attention_weights, representation, transpose_a=True
-    )
-    weighted_representation = tf.squeeze(weighted_representation, -2)
+    weighted_representation = SequencePooling()(representation)
 
     # Classify outputs.
     logits = layers.Dense(num_classes)(weighted_representation)
@@ -309,7 +322,7 @@ def create_cct_model(
 
 
 def run_experiment(model):
-    optimizer = tfa.optimizers.AdamW(learning_rate=0.001, weight_decay=0.0001)
+    optimizer = keras.optimizers.AdamW(learning_rate=0.001, weight_decay=0.0001)
 
     model.compile(
         optimizer=optimizer,
@@ -322,7 +335,7 @@ def run_experiment(model):
         ],
     )
 
-    checkpoint_filepath = "/tmp/checkpoint"
+    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
     checkpoint_callback = keras.callbacks.ModelCheckpoint(
         checkpoint_filepath,
         monitor="val_accuracy",
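
For reference, here is a minimal standalone sketch (not part of the diff) of how the new SequencePooling layer behaves. It assumes keras_core is installed and the TensorFlow backend is selected, mirroring the imports in the converted script, and simply replays the layer on random data to confirm the pooled output shape.

# Standalone sanity check for the SequencePooling layer added in this change.
# Assumes keras_core is available and the TensorFlow backend is used,
# as in the converted cct_keras_core.py.
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
from keras_core import layers


class SequencePooling(layers.Layer):
    def __init__(self):
        super().__init__()
        # A single dense unit scores every token in the sequence.
        self.attention = layers.Dense(1)

    def call(self, x):
        # Softmax over the sequence axis turns the scores into weights,
        # which are then used to take a weighted average of the tokens.
        attention_weights = tf.nn.softmax(self.attention(x), axis=1)
        weighted_representation = tf.matmul(attention_weights, x, transpose_a=True)
        return tf.squeeze(weighted_representation, -2)


# A batch of 4 sequences, each with 64 tokens of 128 features,
# pools down to one 128-d vector per sequence: (4, 128).
pooled = SequencePooling()(tf.random.normal((4, 64, 128)))
print(pooled.shape)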