From 9e40e3047e78fe402417bd87270baee996836ed7 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 9 Mar 2021 21:40:35 -0800 Subject: [PATCH] allow for overlapping super-pixel data, as in the token2token paper, by setting the unfold_args keyword argument --- setup.py | 2 +- transformer_in_transformer/tnt.py | 22 +++++++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index cc7e9b9..6b56f71 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'transformer-in-transformer', packages = find_packages(), - version = '0.0.7', + version = '0.0.8', license='MIT', description = 'Transformer in Transformer - Pytorch', author = 'Phil Wang', diff --git a/transformer_in_transformer/tnt.py b/transformer_in_transformer/tnt.py index ac37c12..17041f4 100644 --- a/transformer_in_transformer/tnt.py +++ b/transformer_in_transformer/tnt.py @@ -10,9 +10,15 @@ def exists(val): return val is not None +def default(val, d): + return val if exists(val) else d + def divisible_by(val, divisor): return (val % divisor) == 0 +def unfold_output_size(image_size, kernel_size, stride, padding): + return int(((image_size - kernel_size + (2 * padding)) / stride) + 1) + # classes class PreNorm(nn.Module): @@ -86,25 +92,31 @@ def __init__( heads = 8, dim_head = 64, ff_dropout = 0., - attn_dropout = 0. + attn_dropout = 0., + unfold_args = None ): super().__init__() assert divisible_by(image_size, patch_size), 'image size must be divisible by patch size' assert divisible_by(patch_size, pixel_size), 'patch size must be divisible by pixel size for now' num_patch_tokens = (image_size // patch_size) ** 2 - pixel_width = patch_size // pixel_size - num_pixels = pixel_width ** 2 self.image_size = image_size self.patch_size = patch_size self.patch_tokens = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim)) + unfold_args = default(unfold_args, (pixel_size, pixel_size, 0)) + unfold_args = (*unfold_args, 0) if len(unfold_args) == 2 else unfold_args + kernel_size, stride, padding = unfold_args + + pixel_width = unfold_output_size(patch_size, kernel_size, stride, padding) + num_pixels = pixel_width ** 2 + self.to_pixel_tokens = nn.Sequential( Rearrange('b c (h p1) (w p2) -> (b h w) c p1 p2', p1 = patch_size, p2 = patch_size), - nn.Unfold(pixel_size, stride = pixel_size), + nn.Unfold(kernel_size = kernel_size, stride = stride, padding = padding), Rearrange('... c n -> ... n c'), - nn.Linear(3 * pixel_size ** 2, pixel_dim) + nn.Linear(3 * kernel_size ** 2, pixel_dim) ) self.patch_pos_emb = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))