Add feature_grad_mult argument to HuBERTPretrainModel (#2335)
Summary:
In Wav2Vec2 and HuBERT model training, the convolutional feature extraction layers use `group_norm` for normalization in the `Base` model, while the `Large` and `XLarge` models use `layer_norm`. For the `Base` model, the gradients of the feature extraction layers are unstable during pre-training, so we scale them down by multiplying them by 0.1.
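
As a rough illustration of what this scaling does (a minimal sketch; the custom autograd function mirrors the `GradMultiply` helper added in this PR, and the 0.1 factor and tensor shape are only illustrative):

import torch


class _ScaleGrad(torch.autograd.Function):
    # Identity in the forward pass; scales the incoming gradient in the backward pass.
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.clone()

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None


x = torch.randn(2, 3, requires_grad=True)
y = _ScaleGrad.apply(x, 0.1)  # forward output equals x
y.sum().backward()
print(x.grad)  # every element is 0.1 instead of 1.0, i.e. the gradient is scaled down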

In this PR, we add such an argument to `HuBERTPretrainModel` to control the gradient of the feature extractor layers. We also expose the argument in the factory functions (`hubert_pretrain_base`, `hubert_pretrain_large`, and `hubert_pretrain_xlarge`). The reason is that in fine-tuning the feature extractor's parameters are fixed, so we can multiply the gradient by 0.0 to avoid back-propagating gradients.
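
For example (a hedged usage sketch, assuming a torchaudio build that includes this change; all other factory arguments are left at their defaults):

from torchaudio.models import hubert_pretrain_base, hubert_pretrain_large

# Base pre-training: scale the feature-extractor gradients by 0.1 (the new default).
base_model = hubert_pretrain_base(feature_grad_mult=0.1)

# Large / XLarge use layer_norm in the extractor, so their gradients are left untouched.
large_model = hubert_pretrain_large(feature_grad_mult=None)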

Pull Request resolved: #2335

Reviewed By: xiaohui-zhang, mthrok

Differential Revision: D35646928

Pulled By: nateanl

fbshipit-source-id: 6a9563e227aac6e3127b634357946d860f26c994
nateanl authored and facebook-github-bot committed May 18, 2022
1 parent c6a376c commit 647f28e
Showing 2 changed files with 52 additions and 4 deletions.
12 changes: 12 additions & 0 deletions torchaudio/models/wav2vec2/components.py
@@ -1037,3 +1037,15 @@ def forward(self, x: Tensor, label: Tensor, mask_m: Tensor, mask_u: Tensor) -> T
label_u = label[mask_u]
logit_u = _compute_logits(proj_x_u, label_u, self.label_embeddings)
return logit_m, logit_u


class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
return res

@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
44 changes: 40 additions & 4 deletions torchaudio/models/wav2vec2/model.py
@@ -137,18 +137,28 @@ class HuBERTPretrainModel(Module):
logit_generator (torch.nn.Module):
Logit generator that predicts the logits of the masked and unmasked inputs.
feature_grad_mult (float or None):
The factor to scale the convolutional feature extraction layer gradients by.
If ``None``, the gradients of feature extraction layers are not affected.
The scale factor will not affect the forward pass.
"""

def __init__(
self,
wav2vec2: Wav2Vec2Model,
mask_generator: Module,
logit_generator: Module,
feature_grad_mult: Optional[float],
):
super().__init__()
self.wav2vec2 = wav2vec2
self.mask_generator = mask_generator
self.logit_generator = logit_generator
assert (
feature_grad_mult is None or 0.0 < feature_grad_mult < 1.0
), f"The value of `feature_grad_mult` must be ``None`` or between (0, 1). Found {feature_grad_mult}"
self.feature_grad_mult = feature_grad_mult

def forward(
self,
@@ -184,6 +194,8 @@ def forward(
Shape: `(1,)`.
"""
x, lengths = self.wav2vec2.feature_extractor(waveforms, audio_lengths)
if self.feature_grad_mult is not None and self.feature_grad_mult < 1.0:
x = components.GradMultiply.apply(x, self.feature_grad_mult)
features_pen = x.float().pow(2).mean()
if lengths is not None:
padding_mask = components._get_padding_mask(x, lengths)
@@ -712,6 +724,7 @@ def hubert_pretrain_model(
skip_nomask: bool,
num_classes: int,
final_dim: int,
feature_grad_mult: Optional[float],
) -> HuBERTPretrainModel:
# Overriding the signature so that the return type is correct on Sphinx
"""hubert_pretrain_model(extractor_mode: str, extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], extractor_conv_bias: bool, encoder_embed_dim: int, encoder_projection_dropout: float, encoder_pos_conv_kernel: int, encoder_pos_conv_groups: int, encoder_num_layers: int, encoder_num_heads: int, encoder_attention_dropout: float, encoder_ff_interm_features: int, encoder_ff_interm_dropout: float, encoder_dropout: float, encoder_layer_norm_first: bool, encoder_layer_drop: float, mask_prob: float, mask_selection: str, mask_other: float, mask_length: int, no_mask_overlap: bool, mask_min_space: int, mask_channel_prob: float, mask_channel_selection: str, mask_channel_other: float, mask_channel_length: int, no_mask_channel_overlap: bool, mask_channel_min_space: int, skip_masked: bool, skip_nomask: bool, num_classes: int, final_dim: int) -> torchaudio.models.HuBERTPretrainModel
@@ -910,6 +923,12 @@ def hubert_pretrain_model(
This option corresponds to ``final_dim`` from ``fairseq``.
feature_grad_mult (float or None):
The factor to scale the convolutional feature extraction layer gradients by.
The scale factor will not affect the forward pass.
This option corresponds to ``feature_grad_mult`` from ``fairseq``.
Returns:
HuBERTPretrainModel:
The resulting model.
@@ -958,7 +977,12 @@ def hubert_pretrain_model(
skip_masked,
skip_nomask,
)
return HuBERTPretrainModel(wav2vec2=wav2vec2, mask_generator=mask_generator, logit_generator=logit_generator)
return HuBERTPretrainModel(
wav2vec2=wav2vec2,
mask_generator=mask_generator,
logit_generator=logit_generator,
feature_grad_mult=feature_grad_mult,
)


def hubert_pretrain_base(
@@ -970,10 +994,11 @@ def hubert_pretrain_base(
mask_prob: float = 0.8,
mask_channel_prob: float = 0.0,
mask_channel_length: int = 10,
feature_grad_mult: Optional[float] = 0.1,
num_classes: int = 100,
) -> HuBERTPretrainModel:
# Overriding the signature so that the return type is correct on Sphinx
"""hubert_pretrain_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.05, mask_prob: float = 0.8, mask_channel_prob: float = 0.0, mask_channel_length: int = 10, num_classes: int = 100) -> torchaudio.models.HuBERTPretrainModel
"""hubert_pretrain_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.05, mask_prob: float = 0.8, mask_channel_prob: float = 0.0, mask_channel_length: int = 10, feature_grad_mult: Optional[float] = 0.1, num_classes: int = 100) -> torchaudio.models.HuBERTPretrainModel
Build HuBERTPretrainModel model with "base" architecture from *HuBERT* [:footcite:`hsu2021hubert`]
@@ -994,6 +1019,8 @@
See :py:func:`hubert_pretrain_model`.
mask_channel_length (int):
See :py:func:`hubert_pretrain_model`.
feature_grad_mult (float or None):
See :py:func:`hubert_pretrain_model`.
num_classes (int, optional):
See :py:func:`hubert_pretrain_model`.
@@ -1033,6 +1060,7 @@ def hubert_pretrain_base(
skip_nomask=False,
num_classes=num_classes,
final_dim=256,
feature_grad_mult=feature_grad_mult,
)


@@ -1045,9 +1073,10 @@ def hubert_pretrain_large(
mask_prob: float = 0.8,
mask_channel_prob: float = 0.0,
mask_channel_length: int = 10,
feature_grad_mult: Optional[float] = None,
) -> HuBERTPretrainModel:
# Overriding the signature so that the return type is correct on Sphinx
"""hubert_pretrain_large(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, mask_prob: float = 0.8, mask_channel_prob: float = 0.0, mask_channel_length: int = 10) -> torchaudio.models.HuBERTPretrainModel
"""hubert_pretrain_large(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, mask_prob: float = 0.8, mask_channel_prob: float = 0.0, mask_channel_length: int = 10, feature_grad_mult: Optional[float] = None) -> torchaudio.models.HuBERTPretrainModel
Build HuBERTPretrainModel model for pre-training with "large" architecture from *HuBERT* [:footcite:`hsu2021hubert`]
@@ -1068,6 +1097,8 @@
See :py:func:`hubert_pretrain_model`.
mask_channel_length (int):
See :py:func:`hubert_pretrain_model`.
feature_grad_mult (float or None):
See :py:func:`hubert_pretrain_model`.
Returns:
HuBERTPretrainModel:
@@ -1105,6 +1136,7 @@ def hubert_pretrain_large(
skip_nomask=False,
num_classes=500,
final_dim=768,
feature_grad_mult=feature_grad_mult,
)


@@ -1117,9 +1149,10 @@ def hubert_pretrain_xlarge(
mask_prob: float = 0.8,
mask_channel_prob: float = 0.0,
mask_channel_length: int = 10,
feature_grad_mult: Optional[float] = None,
) -> HuBERTPretrainModel:
# Overriding the signature so that the return type is correct on Sphinx
"""hubert_pretrain_xlarge(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, mask_prob: float = 0.8, mask_channel_prob: float = 0.0, mask_channel_length: int = 10) -> torchaudio.models.HuBERTPretrainModel
"""hubert_pretrain_xlarge(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, mask_prob: float = 0.8, mask_channel_prob: float = 0.0, mask_channel_length: int = 10, feature_grad_mult: Optional[float] = None) -> torchaudio.models.HuBERTPretrainModel
Build HuBERTPretrainModel model for pre-training with "extra large" architecture from *HuBERT* [:footcite:`hsu2021hubert`]
@@ -1140,6 +1173,8 @@
See :py:func:`hubert_pretrain_model`.
mask_channel_length (int):
See :py:func:`hubert_pretrain_model`.
feature_grad_mult (float or None):
See :py:func:`hubert_pretrain_model`.
Returns:
HuBERTPretrainModel:
@@ -1177,4 +1212,5 @@ def hubert_pretrain_xlarge(
skip_nomask=False,
num_classes=500,
final_dim=1024,
feature_grad_mult=feature_grad_mult,
)
