Fix position encoder in examples/.../token_learner.py #727

Merged
34 changes: 21 additions & 13 deletions examples/keras_io/vision/token_learner.py
@@ -3,7 +3,7 @@
Authors: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution)
Converted to Keras Core by: [Muhammad Anas Raza](https://anasrz.com)
Date created: 2021/12/10
-Last modified: 2023/07/18
+Last modified: 2023/08/14
Description: Adaptively generating a smaller number of tokens for Vision Transformers.
Accelerator: GPU
"""
@@ -165,19 +165,25 @@
"""


-def position_embedding(
-    projected_patches, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM
-):
-    # Build the positions.
-    positions = ops.arange(start=0, stop=num_patches, step=1)
+class PatchEncoder(layers.Layer):
+    def __init__(self, num_patches, projection_dim):
+        super().__init__()
+        self.num_patches = num_patches
+        self.position_embedding = layers.Embedding(
+            input_dim=num_patches, output_dim=projection_dim
+        )

-    # Encode the positions with an Embedding layer.
-    encoded_positions = layers.Embedding(
-        input_dim=num_patches, output_dim=projection_dim
-    )(positions)
+    def call(self, patch):
+        positions = ops.expand_dims(
+            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
+        )
+        encoded = patch + self.position_embedding(positions)
+        return encoded

-    # Add encoded positions to the projected patches.
-    return projected_patches + encoded_positions
+    def get_config(self):
+        config = super().get_config()
+        config.update({"num_patches": self.num_patches})
+        return config


"""
@@ -337,7 +343,9 @@ def create_vit_classifier(
    ) # (B, number_patches, projection_dim)

    # Add positional embeddings to the projected patches.
-    encoded_patches = position_embedding(
+    encoded_patches = PatchEncoder(
+        num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM
+    )(
        projected_patches
    ) # (B, number_patches, projection_dim)
    encoded_patches = layers.Dropout(0.1)(encoded_patches)
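
For reference, below is a minimal sketch of exercising the new PatchEncoder layer from this diff on dummy input. It is not part of the PR; it assumes Keras Core is importable as keras_core, and the NUM_PATCHES and PROJECTION_DIM values are illustrative placeholders rather than the example's actual configuration.

import numpy as np
from keras_core import layers, ops


class PatchEncoder(layers.Layer):  # same class as in the diff above
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        # Position indices 0..num_patches-1 with a leading batch axis,
        # so the embedding broadcasts over the batch dimension of `patch`.
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        return patch + self.position_embedding(positions)

    def get_config(self):
        config = super().get_config()
        config.update({"num_patches": self.num_patches})
        return config


NUM_PATCHES = 64       # illustrative value only
PROJECTION_DIM = 128   # illustrative value only

projected_patches = np.random.rand(2, NUM_PATCHES, PROJECTION_DIM).astype("float32")
encoded = PatchEncoder(num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM)(
    projected_patches
)
print(encoded.shape)  # (2, 64, 128): shape is preserved, a position vector is added to each patch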