Commit 9e40e30
allow for overlapping super-pixel data, as in the token2token paper, by setting the unfold_args keyword argument
lucidrains committed Mar 10, 2021
1 parent 3bbd79f commit 9e40e30
Showing 2 changed files with 18 additions and 6 deletions.
setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'transformer-in-transformer',
   packages = find_packages(),
-  version = '0.0.7',
+  version = '0.0.8',
   license='MIT',
   description = 'Transformer in Transformer - Pytorch',
   author = 'Phil Wang',
transformer_in_transformer/tnt.py (17 additions, 5 deletions)
@@ -10,9 +10,15 @@
 def exists(val):
     return val is not None

+def default(val, d):
+    return val if exists(val) else d
+
 def divisible_by(val, divisor):
     return (val % divisor) == 0

+def unfold_output_size(image_size, kernel_size, stride, padding):
+    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
+
 # classes

 class PreNorm(nn.Module):
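The new unfold_output_size helper is the standard sliding-window output-size formula, int((n - k + 2p) / s) + 1. A minimal sanity-check sketch (the numbers 16, 7, 4, 2 are illustrative, not taken from the commit), cross-checked against what nn.Unfold actually produces:

import torch
import torch.nn as nn

def unfold_output_size(image_size, kernel_size, stride, padding):
    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)

# a 16x16 patch, 7x7 kernel, stride 4, padding 2:
# (16 - 7 + 2 * 2) / 4 + 1 = 13 / 4 + 1 -> 3 + 1 = 4 windows per side
print(unfold_output_size(16, 7, 4, 2))  # 4

# nn.Unfold returns (batch, channels * kernel_size ** 2, num_windows)
patch = torch.randn(1, 3, 16, 16)
out = nn.Unfold(kernel_size = 7, stride = 4, padding = 2)(patch)
print(out.shape)  # torch.Size([1, 147, 16]) -> 16 windows = 4 * 4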
@@ -86,25 +92,31 @@ def __init__(
         heads = 8,
         dim_head = 64,
         ff_dropout = 0.,
-        attn_dropout = 0.
+        attn_dropout = 0.,
+        unfold_args = None
     ):
         super().__init__()
         assert divisible_by(image_size, patch_size), 'image size must be divisible by patch size'
         assert divisible_by(patch_size, pixel_size), 'patch size must be divisible by pixel size for now'

         num_patch_tokens = (image_size // patch_size) ** 2
-        pixel_width = patch_size // pixel_size
-        num_pixels = pixel_width ** 2

         self.image_size = image_size
         self.patch_size = patch_size
         self.patch_tokens = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))

+        unfold_args = default(unfold_args, (pixel_size, pixel_size, 0))
+        unfold_args = (*unfold_args, 0) if len(unfold_args) == 2 else unfold_args
+        kernel_size, stride, padding = unfold_args
+
+        pixel_width = unfold_output_size(patch_size, kernel_size, stride, padding)
+        num_pixels = pixel_width ** 2
+
         self.to_pixel_tokens = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> (b h w) c p1 p2', p1 = patch_size, p2 = patch_size),
-            nn.Unfold(pixel_size, stride = pixel_size),
+            nn.Unfold(kernel_size = kernel_size, stride = stride, padding = padding),
             Rearrange('... c n -> ... n c'),
-            nn.Linear(3 * pixel_size ** 2, pixel_dim)
+            nn.Linear(3 * kernel_size ** 2, pixel_dim)
         )

         self.patch_pos_emb = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
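Taken together, passing unfold_args lets the pixel tokenizer cut each patch into overlapping super-pixels, in the spirit of the token2token soft split. A hedged usage sketch: unfold_args is (kernel_size, stride, padding), or a 2-tuple with padding defaulting to 0; depth and num_classes are assumed from the rest of the file, which this diff does not show:

import torch
from transformer_in_transformer import TNT

tnt = TNT(
    image_size = 256,
    patch_dim = 512,
    pixel_dim = 24,
    patch_size = 16,
    pixel_size = 4,           # only feeds the default unfold_args of (4, 4, 0)
    depth = 6,                # assumed parameter, not visible in this diff
    num_classes = 1000,       # assumed parameter, not visible in this diff
    unfold_args = (7, 4, 2)   # overlapping 7x7 super-pixels, stride 4, padding 2
)

img = torch.randn(1, 3, 256, 256)
logits = tnt(img)  # each 16x16 patch yields a 4x4 grid of overlapping pixel tokens

Passing unfold_args = (7, 4) would behave the same as (7, 4, 0), via the (*unfold_args, 0) normalization in __init__.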
