Commit

Add cross attn to convnext
leng-yue committed Sep 21, 2023
1 parent 7cfa480 commit 844760d
Showing 3 changed files with 142 additions and 54 deletions.
2 changes: 1 addition & 1 deletion fish_diffusion/datasets/naive.py
@@ -264,7 +264,7 @@ class NaiveTTSDataset(NaiveDataset):
]

collating_pipeline = [
dict(type="FilterByLength", key="mel", dim=0, min_length=1, max_length=1024),
dict(type="FilterByLength", key="mel", dim=0, min_length=1, max_length=2048),
dict(type="ListToDict"),
dict(
type="PadStack",
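
Not part of the commit: a rough sketch of what a FilterByLength collating step presumably does, to make clear what raising max_length from 1024 to 2048 changes (the actual fish_diffusion implementation may differ).

# Hypothetical reimplementation of the FilterByLength step, for illustration only.
def filter_by_length(samples, key="mel", dim=0, min_length=1, max_length=2048):
    # Keep only samples whose `key` tensor has between min_length and
    # max_length frames along `dim` (here: mel frames along the time axis).
    return [s for s in samples if min_length <= s[key].shape[dim] <= max_length]
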
150 changes: 136 additions & 14 deletions fish_diffusion/modules/convnext.py
@@ -3,6 +3,7 @@

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from fish_diffusion.modules.wavenet import DiffusionEmbedding
@@ -55,14 +56,22 @@ def forward(
x: torch.Tensor,
condition: Optional[torch.Tensor] = None,
diffusion_step: Optional[torch.Tensor] = None,
x_mask: Optional[torch.Tensor] = None,
condition_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = x

-        x = (
-            x
-            + self.diffusion_step_projection(diffusion_step)
-            + self.condition_projection(condition)
-        )
+        if diffusion_step is not None:
+            x = x + self.diffusion_step_projection(diffusion_step)
+
+        if condition is not None:
+            if condition_mask is not None:
+                condition = condition.masked_fill(condition_mask[:, None, :], 0.0)
+
+            x = x + self.condition_projection(condition)
+
+        if x_mask is not None:
+            x = x.masked_fill(x_mask[:, None, :], 0.0)

x = self.dwconv(x)
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
@@ -75,9 +84,67 @@ def forward(
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)

x = residual + x

if x_mask is not None:
x = x.masked_fill(x_mask[:, None, :], 0.0)

return x


class CrossAttentionBlock(nn.TransformerDecoderLayer):
def __init__(
self,
dim: int,
intermediate_dim: int,
nhead: int = 8,
):
super().__init__(
d_model=dim,
nhead=nhead,
dim_feedforward=intermediate_dim,
activation="gelu",
batch_first=True,
)

self.diffusion_step_projection = nn.Conv1d(dim, dim, 1)
self.register_buffer("positional_embedding", self.get_embedding(dim))

def get_embedding(self, embedding_dim, num_embeddings=4096):
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
1
) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
num_embeddings, -1
)

return emb

def forward(self, x, condition, diffusion_step, x_mask=None, condition_mask=None):
if diffusion_step is not None:
x = x + self.diffusion_step_projection(diffusion_step)

# Apply positional encoding to both x and condition
x = x.transpose(1, 2)
condition = condition.transpose(1, 2)

x = x + self.positional_embedding[: x.size(1)][None]
condition = condition + self.positional_embedding[: condition.size(1)][None]

return (
super()
.forward(
tgt=x,
memory=condition,
tgt_key_padding_mask=x_mask,
memory_key_padding_mask=condition_mask,
)
.transpose(1, 2)
)


class ConvNext(nn.Module):
def __init__(
self,
@@ -88,6 +155,8 @@ def __init__(
num_layers=20,
dilation_cycle=4,
gradient_checkpointing=False,
cross_attention=False,
cross_every_n_layers=5,
):
super(ConvNext, self).__init__()

@@ -104,30 +173,42 @@ def __init__(
nn.Conv1d(dim * mlp_factor, dim, 1),
)

-        self.residual_layers = nn.ModuleList(
-            [
+        self.residual_layers = nn.ModuleList([])
+
+        for i in range(num_layers):
+            if cross_attention and i % cross_every_n_layers == 0:
+                self.residual_layers.append(
+                    CrossAttentionBlock(
+                        dim=dim,
+                        intermediate_dim=dim * mlp_factor,
+                    )
+                )
+
+            self.residual_layers.append(
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=dim * mlp_factor,
                    dilation=2 ** (i % dilation_cycle),
                )
-                for i in range(num_layers)
-            ]
-        )
+            )

self.output_projection = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size=1),
nn.GELU(),
nn.Conv1d(dim, mel_channels, kernel_size=1),
)

self.gradient_checkpointing = gradient_checkpointing
self.cross_attention = cross_attention

-    def forward(self, x, diffusion_step, conditioner):
+    def forward(self, x, diffusion_step, conditioner, x_mask=None, condition_mask=None):
"""
:param x: [B, M, T]
:param diffusion_step: [B,]
-        :param conditioner: [B, M, T]
+        :param conditioner: [B, M, E]
+        :param x_mask: [B, T] bool mask
+        :param condition_mask: [B, E] bool mask
:return:
"""

@@ -145,14 +226,55 @@ def forward(self, x, diffusion_step, conditioner):
diffusion_step = self.diffusion_embedding(diffusion_step).unsqueeze(-1)
condition = self.conditioner_projection(conditioner)

if x_mask is not None:
x = x.masked_fill(x_mask[:, None, :], 0.0)

if condition_mask is not None:
condition = condition.masked_fill(condition_mask[:, None, :], 0.0)

for layer in self.residual_layers:
is_cross_layer = isinstance(layer, CrossAttentionBlock)
temp_condition = (
condition
if ((self.cross_attention is False) or is_cross_layer)
else None
)

if self.training and self.gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
-                    layer, x, condition, diffusion_step
+                    layer, x, temp_condition, diffusion_step, x_mask, condition_mask
)
else:
-                x = layer(x, condition, diffusion_step)
+                x = layer(x, temp_condition, diffusion_step, x_mask, condition_mask)

x = self.output_projection(x) # [B, 128, T]
if x_mask is not None:
x = x.masked_fill(x_mask[:, None, :], 0.0)

return x[:, None] if use_4_dim else x


if __name__ == "__main__":
import torch

gpu_memory_usage = torch.cuda.memory_allocated() / 1024**3
torch.cuda.empty_cache()
torch.cuda.synchronize()

model = ConvNext(
cross_attention=True,
gradient_checkpointing=True,
).cuda()
x = torch.randn(8, 128, 1024).cuda()
diffusion_step = torch.randint(0, 1000, (8,)).cuda()
conditioner = torch.randn(8, 256, 256).cuda()
x_mask = torch.randint(0, 2, (8, 1024)).bool().cuda()
condition_mask = torch.randint(0, 2, (8, 256)).bool().cuda()
y = model(x, diffusion_step, conditioner, x_mask, condition_mask)
print(y.shape)

torch.cuda.empty_cache()
torch.cuda.synchronize()

gpu_memory_usage = torch.cuda.memory_allocated() / 1024**3 - gpu_memory_usage
print(f"GPU memory usage: {gpu_memory_usage:.2f} GB")
44 changes: 5 additions & 39 deletions fish_diffusion/modules/feature_extractors/bert_tokenizer.py
@@ -12,50 +12,16 @@ class BertTokenizer(BaseFeatureExtractor):
def __init__(
self,
model_name: str,
-        transcription_path: str,
+        label_suffix: str = ".txt",
):
super().__init__()

self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.transcription_path = Path(transcription_path)
-
-        self.transcriptions = self._load_transcriptions(transcription_path)
-
-    def _load_transcriptions(self, transcription_path: str):
-        results = {}
-
-        for i in open(transcription_path):
-            id, text = i.split("|")
-
-            results[id] = text.strip()
-
-        return results
+        self.label_suffix = label_suffix

@torch.no_grad()
def forward(self, audio_path: Path):
-        id = str(
-            audio_path.absolute().relative_to(self.transcription_path.parent.absolute())
-        )
-        text = self.transcriptions[id]
-
-        data = self.tokenizer.encode_plus(text, return_offsets_mapping=True)
-
-        input_ids = data["input_ids"]
-        offset_mapping = data["offset_mapping"]
-
-        # Aligning input_ids with offset_mapping
-        new_input_ids = []
-        for input_id, (start, end) in zip(input_ids, offset_mapping):
-            length = end - start
-            new_input_ids.extend(
-                [input_id] + [self.tokenizer.pad_token_id] * (length - 1)
-            )
-
-        # Adding <pad> between each word
-        input_ids = torch.tensor(new_input_ids, dtype=torch.long)
-        new_input_ids = torch.tensor(
-            [self.tokenizer.pad_token_id] * (len(input_ids) * 2 - 1), dtype=torch.long
-        )
-        new_input_ids[::2] = input_ids
+        transcript = audio_path.with_suffix(self.label_suffix).read_text().strip()
+        input_ids = self.tokenizer.encode(transcript, return_tensors="pt")

-        return new_input_ids
+        return input_ids
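
Not part of the commit: a hypothetical call showing the new per-file workflow, where the transcript sits next to the audio file instead of in a shared transcription list. The model name and paths are placeholders, and the extractor is assumed to be callable like an nn.Module.

from pathlib import Path

from fish_diffusion.modules.feature_extractors.bert_tokenizer import BertTokenizer

# Placeholder model name; any HuggingFace tokenizer should work here.
extractor = BertTokenizer(model_name="bert-base-cased", label_suffix=".txt")

# For dataset/audio/0001.wav, the transcript is now read from
# dataset/audio/0001.txt (same stem, label_suffix extension).
input_ids = extractor(Path("dataset/audio/0001.wav"))  # tensor of shape [1, seq_len]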
