From 8f0c8bc1627b3019696f7f3383c6b868dead4207 Mon Sep 17 00:00:00 2001
From: guillaumebaquiast
Date: Sun, 13 Aug 2023 19:15:34 +0200
Subject: [PATCH] position encoding token_learner

---
 examples/keras_io/vision/token_learner.py | 34 ++++++++++++++---------
 1 file changed, 21 insertions(+), 13 deletions(-)

diff --git a/examples/keras_io/vision/token_learner.py b/examples/keras_io/vision/token_learner.py
index aaad6a1d2..96a8f0c5f 100644
--- a/examples/keras_io/vision/token_learner.py
+++ b/examples/keras_io/vision/token_learner.py
@@ -3,7 +3,7 @@
 Authors: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution)
 Converted to Keras Core by: [Muhammad Anas Raza](https://anasrz.com)
 Date created: 2021/12/10
-Last modified: 2023/07/18
+Last modified: 2023/08/14
 Description: Adaptively generating a smaller number of tokens for Vision Transformers.
 Accelerator: GPU
 """
@@ -165,19 +165,25 @@
 """
 
 
-def position_embedding(
-    projected_patches, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM
-):
-    # Build the positions.
-    positions = ops.arange(start=0, stop=num_patches, step=1)
+class PatchEncoder(layers.Layer):
+    def __init__(self, num_patches, projection_dim):
+        super().__init__()
+        self.num_patches = num_patches
+        self.position_embedding = layers.Embedding(
+            input_dim=num_patches, output_dim=projection_dim
+        )
 
-    # Encode the positions with an Embedding layer.
-    encoded_positions = layers.Embedding(
-        input_dim=num_patches, output_dim=projection_dim
-    )(positions)
+    def call(self, patch):
+        positions = ops.expand_dims(
+            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
+        )
+        encoded = patch + self.position_embedding(positions)
+        return encoded
 
-    # Add encoded positions to the projected patches.
-    return projected_patches + encoded_positions
+    def get_config(self):
+        config = super().get_config()
+        config.update({"num_patches": self.num_patches})
+        return config
 
 
 """
@@ -337,7 +343,9 @@ def create_vit_classifier(
     )  # (B, number_patches, projection_dim)
 
     # Add positional embeddings to the projected patches.
-    encoded_patches = position_embedding(
+    encoded_patches = PatchEncoder(
+        num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM
+    )(
        projected_patches
    )  # (B, number_patches, projection_dim)
    encoded_patches = layers.Dropout(0.1)(encoded_patches)
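
Note (illustrative, not part of the patch): the change replaces the free-standing position_embedding helper with a PatchEncoder layer, so the learned position table lives inside a proper Keras layer and is carried in the model config via get_config. Below is a minimal standalone sketch of how the new layer behaves on dummy patch tokens. It assumes Keras 3-style imports (from keras import layers, ops; with Keras Core the equivalent would be from keras_core import layers, ops) and the illustrative values NUM_PATCHES = 64 and PROJECTION_DIM = 128, which are not taken from the example file.

import numpy as np
from keras import layers, ops


class PatchEncoder(layers.Layer):
    """Adds a learned positional embedding to each projected patch token."""

    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        # Positions 0 .. num_patches-1, with a leading batch axis so the
        # position embeddings broadcast over the batch of patch tokens.
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        return patch + self.position_embedding(positions)

    def get_config(self):
        config = super().get_config()
        config.update({"num_patches": self.num_patches})
        return config


NUM_PATCHES = 64      # illustrative; the example derives this from image/patch size
PROJECTION_DIM = 128  # illustrative

projected_patches = np.random.rand(2, NUM_PATCHES, PROJECTION_DIM).astype("float32")
encoded_patches = PatchEncoder(NUM_PATCHES, PROJECTION_DIM)(projected_patches)
print(encoded_patches.shape)  # (2, 64, 128)

Wrapping the logic in a layer (rather than calling layers.Embedding inside a plain function) keeps the position table registered as a tracked weight of a named layer and lets the layer be re-instantiated from its config when the model is saved and reloaded.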