Fix broadcat, probabilitiy, etc (PaddlePaddle#430)
Fix broadcat, probabilitiy, etc
Run pre-commit

Co-authored-by: LokeZhou <aishenghuoaiqq@163.com>
Co-authored-by: wangguanzhong <jerrywgz@126.com>
3 people authored Feb 26, 2024
1 parent e93c1fc commit 9b349ac
Showing 11 changed files with 343 additions and 463 deletions.
124 changes: 63 additions & 61 deletions paddlemix/models/audioldm2/audiomae/mae.py
@@ -16,26 +16,27 @@

import paddle
import paddle.nn as nn
from ..utils import to_2tuple, DropPath, Mlp

from ..clap_module.htsat_model import SwinTransformerBlock
from ..utils import DropPath, Mlp, to_2tuple

class Attention(nn.Layer):

class Attention(nn.Layer):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Layer = nn.LayerNorm,
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Layer = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.scale = self.head_dim**-0.5
self.fused_attn = False

self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
@@ -53,35 +54,38 @@ def forward(self, x: paddle.Tensor) -> paddle.Tensor:

if self.fused_attn:
x = nn.functional.scaled_dot_product_attention(
q, k, v,
dropout=self.attn_drop.p if self.training else 0.,
q,
k,
v,
dropout=self.attn_drop.p if self.training else 0.0,
)[0]
else:
q = q * self.scale
k_perm = list(range(k.dim()))
new_perm = k_perm
new_perm[-2],new_perm[-1] = k_perm[-1],k_perm[-2]
new_perm[-2], new_perm[-1] = k_perm[-1], k_perm[-2]
attn = q @ k.transpose(new_perm)
attn = nn.functional.softmax(attn,axis=-1)
attn = nn.functional.softmax(attn, axis=-1)
attn = self.attn_drop(attn)
x = attn @ v

x_perm = list(range(x.dim()))
new_perm = x_perm
new_perm[1],new_perm[2] = x_perm[2],x_perm[1]
new_perm[1], new_perm[2] = x_perm[2], x_perm[1]
x = x.transpose(new_perm).reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x



class LayerScale(nn.Layer):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
tmp = init_values * paddle.ones(dim)
self.gamma = paddle.create_parameter(shape=tmp.shape,
dtype=tmp.dtype,
default_initializer=nn.initializer.Assign(tmp))
self.gamma = paddle.create_parameter(
shape=tmp.shape, dtype=tmp.dtype, default_initializer=nn.initializer.Assign(tmp)
)
self.gamma.stop_gradient = False

def forward(self, x):
@@ -91,39 +95,40 @@ def forward(self, x):
else:
return x * self.gamma
# return paddle.multiply(x, self.gamma) if self.inplace else x * self.gamma

class Block(nn.Layer):


class Block(nn.Layer):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
drop=0.0,
attn_drop=0.0,
init_values=None,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

def forward(self, x):
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x



class PatchEmbed_org(nn.Layer):
"""Image to Patch Embedding"""

@@ -137,9 +142,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
self.patch_size = patch_size
self.num_patches = num_patches

self.proj = nn.Conv2D(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

def forward(self, x):
B, C, H, W = x.shape
@@ -149,7 +152,8 @@ def forward(self, x):
x = self.proj(x)
y = x.flatten(2).transpose([0, 2, 1])
return y



class MaskedAutoencoderViT(nn.Layer):
"""Masked Autoencoder with VisionTransformer backbone"""

@@ -194,19 +198,19 @@ def __init__(
# MAE encoder specifics
self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)
self.use_custom_patch = use_custom_patch

num_patches = self.patch_embed.num_patches
tmp = paddle.zeros([1, 1, embed_dim])
self.cls_token = paddle.create_parameter(shape=tmp.shape,
dtype=tmp.dtype,
default_initializer=nn.initializer.Assign(tmp))
self.cls_token = paddle.create_parameter(
shape=tmp.shape, dtype=tmp.dtype, default_initializer=nn.initializer.Assign(tmp)
)
self.cls_token.stop_gradient = False

# self.split_pos = split_pos # not useful
tmp = paddle.zeros([1, num_patches + 1, embed_dim])
self.pos_embed = paddle.create_parameter(shape=tmp.shape,
dtype=tmp.dtype,
default_initializer=nn.initializer.Assign(tmp)) # fixed sin-cos embedding
self.pos_embed = paddle.create_parameter(
shape=tmp.shape, dtype=tmp.dtype, default_initializer=nn.initializer.Assign(tmp)
) # fixed sin-cos embedding
self.pos_embed.stop_gradient = not pos_trainable

self.encoder_depth = depth
@@ -230,23 +234,21 @@ def __init__(
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias_attr=True)

tmp = paddle.zeros([1, 1, decoder_embed_dim])
self.mask_token = paddle.create_parameter(shape=tmp.shape,
dtype=tmp.dtype,
default_initializer=nn.initializer.Assign(tmp))
self.mask_token = paddle.create_parameter(
shape=tmp.shape, dtype=tmp.dtype, default_initializer=nn.initializer.Assign(tmp)
)
self.mask_token.stop_gradient = False

tmp = paddle.zeros([1, num_patches + 1, decoder_embed_dim])
self.decoder_pos_embed = paddle.create_parameter(shape=tmp.shape,
dtype=tmp.dtype,
default_initializer=nn.initializer.Assign(tmp)) # fixed sin-cos embedding
self.decoder_pos_embed = paddle.create_parameter(
shape=tmp.shape, dtype=tmp.dtype, default_initializer=nn.initializer.Assign(tmp)
) # fixed sin-cos embedding
self.decoder_pos_embed.stop_gradient = not pos_trainable

self.no_shift = no_shift

self.decoder_mode = decoder_mode
if (
self.use_custom_patch
): # overlapped patches as in AST. Similar performance yet compute heavy
if self.use_custom_patch: # overlapped patches as in AST. Similar performance yet compute heavy
window_size = (6, 6)
feat_size = (102, 12)
else:
@@ -280,7 +282,7 @@ def __init__(
)
self.decoder_blocks = nn.LayerList(decoder_modules)
else:
# Transfomer
# Transformer
self.decoder_blocks = nn.LayerList(
[
Block(
@@ -326,17 +328,17 @@ def __init__(
def forward_encoder_no_mask(self, x):
# embed patches
x = self.patch_embed(x)

# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]

# masking: length -> length * mask_ratio
# x, mask, ids_restore = self.random_masking(x, mask_ratio)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand([x.shape[0], -1, -1])
x = paddle.concat((cls_tokens, x), axis=1)

# apply Transformer blocks
contextual_embs = []
for n, blk in enumerate(self.blocks):
[remainder of this file's diff, and the other 10 changed files, not loaded in this view]
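For reference, the unfused attention path reformatted in the first hunk reduces to a plain transpose-and-matmul computation. Below is a minimal standalone sketch of that computation in Paddle; the tensor shapes are illustrative assumptions, not values taken from the model config.

import paddle
import paddle.nn.functional as F

# Assumed shapes: batch=2, heads=8, tokens=16, head_dim=64
q = paddle.randn([2, 8, 16, 64])
k = paddle.randn([2, 8, 16, 64])
v = paddle.randn([2, 8, 16, 64])

scale = 64 ** -0.5   # head_dim ** -0.5, as in Attention.__init__
q = q * scale

# Swap the last two axes of k so that q @ k^T yields [2, 8, 16, 16] scores,
# mirroring the permutation list built in Attention.forward above.
perm = list(range(k.dim()))
perm[-2], perm[-1] = perm[-1], perm[-2]
attn = F.softmax(q @ k.transpose(perm), axis=-1)  # normalize over the key axis

out = attn @ v  # weighted sum of values, shape [2, 8, 16, 64]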

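The cls_token, pos_embed, mask_token, and decoder_pos_embed parameters reformatted above all follow the same Paddle idiom: materialize an initial tensor, wrap it with paddle.create_parameter via an Assign initializer, then toggle stop_gradient to control trainability. A minimal sketch of that pattern, assuming an illustrative embed_dim:

import paddle
import paddle.nn as nn

embed_dim = 768  # illustrative value, not read from the model config

# Build the initial value, then register it as a learnable parameter.
tmp = paddle.zeros([1, 1, embed_dim])
cls_token = paddle.create_parameter(
    shape=tmp.shape,
    dtype=tmp.dtype,
    default_initializer=nn.initializer.Assign(tmp),
)
# False keeps the parameter trainable; the diff sets it from `not pos_trainable`
# when the embedding should stay fixed (e.g. a sin-cos pos_embed).
cls_token.stop_gradient = False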
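PatchEmbed_org, whose projection call is reflowed in the diff, converts an image into a token sequence with a strided convolution followed by flatten and transpose. A small sketch of the shape flow, under assumed input sizes:

import paddle
import paddle.nn as nn

# Assumed sizes: 224x224 RGB input, 16x16 patches, 768-dim embeddings
proj = nn.Conv2D(3, 768, kernel_size=16, stride=16)

x = paddle.randn([2, 3, 224, 224])     # [B, C, H, W]
x = proj(x)                            # [2, 768, 14, 14]: one embedding per patch
y = x.flatten(2).transpose([0, 2, 1])  # [2, 196, 768]: (B, num_patches, embed_dim)
print(y.shape)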