From 01ada04ebc39680ae57128ef88c93635f826697f Mon Sep 17 00:00:00 2001 From: Alan Fang <614478391@qq.com> Date: Mon, 8 Apr 2024 01:28:55 +0800 Subject: [PATCH] LoRA support (#2049) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * support lora for v3.0.1 * format code and update lora attention && encoder * fix bug when lora_list is None --------- Co-authored-by: Xingchen Song(宋星辰) --- wenet/bin/recognize.py | 5 + wenet/bin/train.py | 4 +- wenet/finetune/lora/attention.py | 115 +++++++++++ wenet/finetune/lora/encoder.py | 227 +++++++++++++++++++++ wenet/finetune/lora/layers.py | 338 +++++++++++++++++++++++++++++++ wenet/finetune/lora/utils.py | 63 ++++++ wenet/utils/init_model.py | 12 ++ wenet/utils/train_utils.py | 34 ++++ 8 files changed, 797 insertions(+), 1 deletion(-) create mode 100644 wenet/finetune/lora/attention.py create mode 100644 wenet/finetune/lora/encoder.py create mode 100644 wenet/finetune/lora/layers.py create mode 100644 wenet/finetune/lora/utils.py diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index f8a11b922..3779b74ec 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -171,6 +171,11 @@ def get_args(): default=0.0, help='''The higher the score, the greater the degree of bias using decoding-graph for biasing''') + + parser.add_argument('--use_lora', + type=bool, + default=False, + help='''Whether to use lora for biasing''') args = parser.parse_args() print(args) return args diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 1ddfe0435..a8a18fdcf 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -34,7 +34,8 @@ add_deepspeed_args, add_trace_args, init_distributed, init_dataset_and_dataloader, check_modify_and_save_config, init_optimizer_and_scheduler, init_scaler, trace_and_print_model, - wrap_cuda_model, init_summarywriter, save_model, log_per_epoch) + wrap_cuda_model, init_summarywriter, save_model, log_per_epoch, + add_lora_args) def get_args(): @@ -46,6 +47,7 @@ def get_args(): parser = add_model_args(parser) parser = add_dataset_args(parser) parser = add_ddp_args(parser) + parser = add_lora_args(parser) parser = add_deepspeed_args(parser) parser = add_fsdp_args(parser) parser = add_trace_args(parser) diff --git a/wenet/finetune/lora/attention.py b/wenet/finetune/lora/attention.py new file mode 100644 index 000000000..fb758bdcb --- /dev/null +++ b/wenet/finetune/lora/attention.py @@ -0,0 +1,115 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# 2024 Alan (alanfangemail@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Multi-Head Attention layer definition with lora.""" + +from typing import Optional, List + +import torch +from torch import nn + +from wenet.transformer.attention import (MultiHeadedAttention, + RelPositionMultiHeadedAttention) +import wenet.finetune.lora.layers as lora + + +class LoRAMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with lora. 
+ + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + use_sdpa: bool = False, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + lora_rank: int = 8, + lora_alpha: int = 8, + lora_dropout: float = 0.0, + lora_list: Optional[List[str]] = None): + """Construct an MultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias, + value_bias, use_sdpa) + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_out = lora.Linear( + n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, + lora_dropout=lora_dropout + ) if lora_list and "o" in lora_list else nn.Linear(n_feat, n_feat) + + lora_qkv_dict = { + "q": lora_list and "q" in lora_list, + "k": lora_list and "k" in lora_list, + "v": lora_list and "v" in lora_list + } + bias_dict = {"q": query_bias, "k": key_bias, "v": value_bias} + + for key, value in lora_qkv_dict.items(): + setattr(self, f"linear_{key}", + lora.Linear(n_feat, n_feat, r=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + bias=bias_dict[key]) + if value else nn.Linear(n_feat, n_feat, bias_dict[key])) + self.dropout = nn.Dropout(p=dropout_rate) + + +class LoRARelPositionMultiHeadedAttention(LoRAMultiHeadedAttention, + RelPositionMultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + use_sdpa: bool = False, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + lora_rank: int = 8, + lora_alpha: int = 8, + lora_dropout: float = 0.0, + lora_list: Optional[List[str]] = None): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias, + value_bias, use_sdpa, lora_rank, lora_alpha, + lora_dropout, lora_list) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) diff --git a/wenet/finetune/lora/encoder.py b/wenet/finetune/lora/encoder.py new file mode 100644 index 000000000..fa8e3183a --- /dev/null +++ b/wenet/finetune/lora/encoder.py @@ -0,0 +1,227 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# 2024 Alan (alanfangemail@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Encoder definition with lora.""" + +from typing import Optional, List + +import torch + +from wenet.transformer.convolution import ConvolutionModule +from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder +from wenet.transformer.encoder_layer import TransformerEncoderLayer +from wenet.transformer.encoder_layer import ConformerEncoderLayer +from wenet.utils.class_utils import ( + WENET_MLP_CLASSES, + WENET_ACTIVATION_CLASSES, +) +from wenet.finetune.lora.utils import WENET_LORA_ATTENTION_CLASSES + + +class LoRATransformerEncoder(TransformerEncoder): + """Transformer encoder module with lora.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + mlp_bias: bool = True, + activation_type: str = "relu", + gradient_checkpointing: bool = False, + use_sdpa: bool = False, + mlp_type: str = 'position_wise_feed_forward', + layer_norm_type: str = 'layer_norm', + norm_eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + lora_rank: int = 8, + lora_alpha: int = 8, + lora_dropout: float = 0.0, + lora_list: Optional[List[str]] = None, + ): + """ Construct TransformerEncoder + + See Encoder for the meaning of each parameter. 
+ """ + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, query_bias, key_bias, + value_bias, mlp_bias, activation_type, + gradient_checkpointing, use_sdpa, mlp_type, + layer_norm_type, norm_eps, n_kv_head, head_dim) + activation = WENET_ACTIVATION_CLASSES[activation_type]() + mlp_class = WENET_MLP_CLASSES[mlp_type] + self.encoders = torch.nn.ModuleList([ + TransformerEncoderLayer( + output_size, + WENET_LORA_ATTENTION_CLASSES["selfattn"](attention_heads, + output_size, + attention_dropout_rate, + query_bias, key_bias, + value_bias, use_sdpa, + n_kv_head, head_dim, + lora_rank, lora_alpha, + lora_dropout, + lora_list), + mlp_class(output_size, linear_units, dropout_rate, activation, + mlp_bias), + dropout_rate, + normalize_before, + layer_norm_type=layer_norm_type, + norm_eps=norm_eps, + ) for _ in range(num_blocks) + ]) + + +class LoRAConformerEncoder(ConformerEncoder): + """Conformer encoder module with lora.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "rel_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + positionwise_conv_kernel_size: int = 1, + macaron_style: bool = True, + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + mlp_bias: bool = True, + conv_bias: bool = True, + gradient_checkpointing: bool = False, + use_sdpa: bool = False, + mlp_type: str = 'position_wise_feed_forward', + layer_norm_type: str = 'layer_norm', + norm_eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + lora_rank: int = 8, + lora_alpha: int = 8, + lora_dropout: float = 0.0, + lora_list: Optional[List[str]] = None, + ): + """Construct ConformerEncoder + + Args: + input_size to use_dynamic_chunk, see in BaseEncoder + positionwise_conv_kernel_size (int): Kernel size of positionwise + conv1d layer. + macaron_style (bool): Whether to use macaron style for + positionwise layer. + selfattention_layer_type (str): Encoder attention layer type, + the parameter has no effect now, it's just for configure + compatibility. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): whether to use causal convolution or not. + key_bias: whether use bias in attention.linear_k, False for whisper models. 
+ """ + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, positionwise_conv_kernel_size, + macaron_style, selfattention_layer_type, + activation_type, use_cnn_module, cnn_module_kernel, + causal, cnn_module_norm, query_bias, key_bias, + value_bias, mlp_bias, conv_bias, + gradient_checkpointing, use_sdpa, mlp_type, + layer_norm_type, norm_eps, n_kv_head, head_dim) + activation = WENET_ACTIVATION_CLASSES[activation_type]() + + # self-attention module definition + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + query_bias, + key_bias, + value_bias, + use_sdpa, + n_kv_head, + head_dim, + lora_rank, + lora_alpha, + lora_dropout, + lora_list, + ) + # feed-forward module definition + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + mlp_bias, + ) + # convolution module definition + convolution_layer_args = (output_size, cnn_module_kernel, activation, + cnn_module_norm, causal, conv_bias) + + mlp_class = WENET_MLP_CLASSES[mlp_type] + self.encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + WENET_LORA_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args), + mlp_class(*positionwise_layer_args), + mlp_class(*positionwise_layer_args) if macaron_style else None, + ConvolutionModule( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + layer_norm_type=layer_norm_type, + norm_eps=norm_eps, + ) for _ in range(num_blocks) + ]) diff --git a/wenet/finetune/lora/layers.py b/wenet/finetune/lora/layers.py new file mode 100644 index 000000000..a77a8fc13 --- /dev/null +++ b/wenet/finetune/lora/layers.py @@ -0,0 +1,338 @@ +# Copyright (c) 2021 microsoft +# 2023 Alan (alanfangemail@gmail.com) +# ----------------------------------------------------------------------------- +# Licensed under the MIT License (MIT). See LICENSE in the repo root for +# license information. 
+# ----------------------------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import math +from typing import List + + +class LoRALayer(): + def __init__( + self, + r: int, + lora_alpha: int, + lora_dropout: float, + merge_weights: bool, + ): + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = self.identity + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + + def identity(self, x): + return x + + +class Embedding(nn.Embedding, LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + r: int = 0, + lora_alpha: int = 1, + merge_weights: bool = True, + **kwargs + ): + nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0, + merge_weights=merge_weights) + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings))) + self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + + def reset_parameters(self): + nn.Embedding.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.zeros_(self.lora_A) + nn.init.normal_(self.lora_B) + + def train(self, mode: bool = True): + nn.Embedding.train(self, mode) + if mode: + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0: + temp = (self.lora_B @ self.lora_A).transpose(0, 1) + self.weight.data -= temp * self.scaling + self.merged = False + else: + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0: + temp = (self.lora_B @ self.lora_A).transpose(0, 1) + self.weight.data += temp * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor): + if self.r > 0 and not self.merged: + result = nn.Embedding.forward(self, x) + after_A = F.embedding( + x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse + ) + result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling + return result + else: + return nn.Embedding.forward(self, x) + + +class Linear(nn.Linear, LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, + # Set this to True if the layer to replace stores weight like (fan_in, + # fan_out) + merge_weights: bool = True, + **kwargs + ): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, + merge_weights=merge_weights) + + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def 
reset_parameters(self): + nn.Linear.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def T(self, w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + + def train(self, mode: bool = True): + nn.Linear.train(self, mode) + if mode: + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0: + temp = self.T(self.lora_B @ self.lora_A) + self.weight.data -= temp * self.scaling + self.merged = False + else: + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0: + temp = self.T(self.lora_B @ self.lora_A) + self.weight.data += temp * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor): + if self.r > 0 and not self.merged: + result = F.linear(x, self.T(self.weight), bias=self.bias) + result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) + @ self.lora_B.transpose(0, 1)) * self.scaling + return result + else: + return F.linear(x, self.T(self.weight), bias=self.bias) + + +class MergedLinear(nn.Linear, LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + enable_lora: List[bool] = None, + fan_in_fan_out: bool = False, + merge_weights: bool = True, + **kwargs + ): + if enable_lora is None: + enable_lora = [False] + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, + merge_weights=merge_weights) + assert out_features % len(enable_lora) == 0, \ + 'The length of enable_lora must divide out_features' + self.enable_lora = enable_lora + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0 and any(enable_lora): + self.lora_A = nn.Parameter( + self.weight.new_zeros((r * sum(enable_lora), in_features))) + self.lora_B = nn.Parameter( + self.weight.new_zeros((out_features // len(enable_lora) * + sum(enable_lora), r))) + # weights for Conv1D with groups=sum(enable_lora) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + # Compute the indices + self.lora_ind = self.weight.new_zeros( + (out_features, ), dtype=torch.bool + ).view(len(enable_lora), -1) + self.lora_ind[enable_lora, :] = True + self.lora_ind = self.lora_ind.view(-1) + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def reset_parameters(self): + nn.Linear.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def zero_pad(self, x): + result = x.new_zeros((len(self.lora_ind), *x.size()[1:])) + result[self.lora_ind] = x + return result + + def T(self, w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + + def merge_AB(self): + delta_w = F.conv1d( + self.lora_A.unsqueeze(0), + self.lora_B.unsqueeze(-1), + groups=sum(self.enable_lora) + ).squeeze(0) + return self.T(delta_w) + + def train(self, mode: bool = True): + nn.Linear.train(self, mode) + if mode: + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0 and any(self.enable_lora): + self.weight.data -= self.merge_AB() * self.scaling + 
self.merged = False + else: + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0 and any(self.enable_lora): + self.weight.data += self.merge_AB() * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor): + if self.merged: + return F.linear(x, self.T(self.weight), bias=self.bias) + else: + result = F.linear(x, self.T(self.weight), bias=self.bias) + if self.r > 0: + temp = self.T(self.merge_AB().T) + result += self.lora_dropout(x) @ temp * self.scaling + return result + + +class ConvLoRA(nn.Module, LoRALayer): + def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, + lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs): + super(ConvLoRA, self).__init__() + self.conv = conv_module(in_channels, out_channels, kernel_size, + **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + merge_weights=merge_weights) + assert isinstance(kernel_size, int) + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter( + self.conv.weight.new_zeros((r * kernel_size, + in_channels * kernel_size)) + ) + self.lora_B = nn.Parameter( + self.conv.weight.new_zeros( + (out_channels // self.conv.groups * kernel_size, + r * kernel_size)) + ) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.conv.weight.requires_grad = False + self.reset_parameters() + self.merged = False + + def reset_parameters(self): + self.conv.reset_parameters() + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def train(self, mode=True): + super(ConvLoRA, self).train(mode) + if mode: + if self.merge_weights and self.merged: + if self.r > 0: + # Make sure that the weights are not merged + self.conv.weight.data -= ( + self.lora_B @ self.lora_A + ).view(self.conv.weight.shape) * self.scaling + self.merged = False + else: + if self.merge_weights and not self.merged: + if self.r > 0: + # Merge the weights and mark it + self.conv.weight.data += ( + self.lora_B @ self.lora_A + ).view(self.conv.weight.shape) * self.scaling + self.merged = True + + def forward(self, x): + if self.r > 0 and not self.merged: + return self.conv._conv_forward( + x, + self.conv.weight + ( + self.lora_B @ self.lora_A + ).view(self.conv.weight.shape) * self.scaling, + self.conv.bias + ) + return self.conv(x) + + +class Conv2d(ConvLoRA): + def __init__(self, *args, **kwargs): + super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs) + + +class Conv1d(ConvLoRA): + def __init__(self, *args, **kwargs): + super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs) + + +# Can Extend to other ones like this +class Conv3d(ConvLoRA): + def __init__(self, *args, **kwargs): + super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs) diff --git a/wenet/finetune/lora/utils.py b/wenet/finetune/lora/utils.py new file mode 100644 index 000000000..76604b688 --- /dev/null +++ b/wenet/finetune/lora/utils.py @@ -0,0 +1,63 @@ +# Copyright (c) 2021 microsoft +# 2023 Alan (alanfangemail@gmail.com) +# ----------------------------------------------------------------------------- +# Licensed under the MIT License (MIT). See LICENSE in the repo root for +# license information. 
+# ----------------------------------------------------------------------------- + +import logging +import torch +import torch.nn as nn + +from typing import Dict + +from wenet.finetune.lora.attention import (LoRARelPositionMultiHeadedAttention, + LoRAMultiHeadedAttention) +from wenet.finetune.lora.layers import LoRALayer + + +WENET_LORA_ATTENTION_CLASSES = { + "selfattn": LoRAMultiHeadedAttention, + "rel_selfattn": LoRARelPositionMultiHeadedAttention, +} + + +def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None: + logging.info('freezing all params except lora module.') + for n, p in model.named_parameters(): + if 'lora_' not in n: + p.requires_grad = False + if bias == 'none': + return + elif bias == 'all': + for n, p in model.named_parameters(): + if 'bias' in n: + p.requires_grad = True + elif bias == 'lora_only': + for m in model.modules(): + if isinstance(m, LoRALayer) and \ + hasattr(m, 'bias') and \ + m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError + + +def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]: + my_state_dict = model.state_dict() + if bias == 'none': + return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k} + elif bias == 'all': + return {k: my_state_dict[k] for k in my_state_dict + if 'lora_' in k or 'bias' in k} + elif bias == 'lora_only': + to_return = {} + for k in my_state_dict: + if 'lora_' in k: + to_return[k] = my_state_dict[k] + bias_name = k.split('lora_')[0] + 'bias' + if bias_name in my_state_dict: + to_return[bias_name] = my_state_dict[bias_name] + return to_return + else: + raise NotImplementedError diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 1255fb556..b278b6dd9 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -14,6 +14,7 @@ import torch +from wenet.finetune.lora.utils import mark_only_lora_as_trainable from wenet.k2.model import K2Model from wenet.paraformer.cif import Cif from wenet.paraformer.layers import SanmDecoder, SanmEncoder @@ -36,6 +37,8 @@ from wenet.whisper.whisper import Whisper from wenet.utils.cmvn import load_cmvn from wenet.utils.checkpoint import load_checkpoint, load_trained_modules +from wenet.finetune.lora.encoder import (LoRATransformerEncoder, + LoRAConformerEncoder) WENET_ENCODER_CLASSES = { "transformer": TransformerEncoder, @@ -47,6 +50,8 @@ "dual_transformer": DualTransformerEncoder, "dual_conformer": DualConformerEncoder, 'sanm_encoder': SanmEncoder, + "lora_transformer": LoRATransformerEncoder, + "lora_conformer": LoRAConformerEncoder, } WENET_DECODER_CLASSES = { @@ -100,6 +105,9 @@ def init_model(args, configs): decoder_type = configs.get('decoder', 'bitransformer') ctc_type = configs.get('ctc', 'ctc') + if hasattr(args, 'use_lora') and args.use_lora: + encoder_type = "lora_" + encoder_type + encoder = WENET_ENCODER_CLASSES[encoder_type]( input_dim, global_cmvn=global_cmvn, @@ -168,6 +176,10 @@ def init_model(args, configs): else: infos = {} configs["init_infos"] = infos + + if hasattr(args, 'only_optimize_lora') and args.only_optimize_lora: + mark_only_lora_as_trainable(model, bias='lora_only') + print(configs) # Tie emb.weight to decoder.output_layer.weight diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index 7e5eed788..06bb71b02 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -113,6 +113,34 @@ def add_dataset_args(parser): return parser +def add_lora_args(parser): + parser.add_argument("--use_lora", + 
default=False, + type=bool, + help="whether to use lora for finetuning.") + parser.add_argument("--only_optimize_lora", + default=False, + type=bool, + help="freeze all other parameters and only optimize \ + LoRA-related parameters.") + parser.add_argument("--lora_list", + default=['o', 'q', 'k', 'v'], + help="lora module list.") + parser.add_argument("--lora_rank", + default=8, + type=int, + help="lora rank num.") + parser.add_argument("--lora_alpha", + default=8, + type=int, + help="lora scale param, scale=lora_alpha/lora_rank.") + parser.add_argument("--lora_dropout", + default=0, + type=float, + help="lora dropout param.") + return parser + + def add_ddp_args(parser): parser.add_argument('--ddp.dist_backend', dest='dist_backend', @@ -247,6 +275,12 @@ def check_modify_and_save_config(args, configs, symbol_table): assert ds_configs["gradient_clipping"] == configs['grad_clip'] assert ds_configs["steps_per_print"] == configs['log_interval'] + if args.use_lora: + configs['encoder_conf']['lora_list'] = args.lora_list + configs['encoder_conf']['lora_rank'] = args.lora_rank + configs['encoder_conf']['lora_alpha'] = args.lora_alpha + configs['encoder_conf']['lora_dropout'] = args.lora_dropout + if 'input_dim' not in configs: if 'fbank_conf' in configs['dataset_conf']: input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
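
For reference, the rank-r update that wenet/finetune/lora/layers.Linear adds on top of the frozen projection reduces to the short standalone sketch below (dropout omitted; the function name and tensor arguments are illustrative, not part of the patch):

    import torch
    import torch.nn.functional as F

    def lora_linear(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
                    lora_A: torch.Tensor, lora_B: torch.Tensor,
                    lora_alpha: int) -> torch.Tensor:
        """x: (..., in); weight: (out, in); lora_A: (r, in); lora_B: (out, r)."""
        scaling = lora_alpha / lora_A.size(0)          # alpha / r
        base = F.linear(x, weight, bias)               # frozen pre-trained projection
        update = (x @ lora_A.T @ lora_B.T) * scaling   # trainable rank-r correction
        return base + update

When merge_weights is enabled, Linear.train(False) folds the same correction into the frozen matrix (weight += (lora_B @ lora_A) * scaling), so merged inference adds no extra cost.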
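
A hypothetical usage sketch of the freezing and checkpoint helpers added in wenet/finetune/lora/utils.py, assuming this patch is applied (the toy model and sizes below are made up for illustration):

    import torch.nn as nn
    import wenet.finetune.lora.layers as lora
    from wenet.finetune.lora.utils import mark_only_lora_as_trainable, lora_state_dict

    # A toy model mixing one LoRA-adapted projection with a plain nn.Linear.
    model = nn.Sequential(lora.Linear(256, 256, r=8, lora_alpha=8, lora_dropout=0.1),
                          nn.Linear(256, 256))
    mark_only_lora_as_trainable(model, bias='lora_only')

    # Only lora_A / lora_B (plus biases of LoRA layers, because bias='lora_only')
    # remain trainable; everything else is frozen.
    print([n for n, p in model.named_parameters() if p.requires_grad])
    # lora_state_dict() extracts just the LoRA weights for a small adapter checkpoint.
    print(lora_state_dict(model).keys())

In train.py the same path is driven by the new flags: --use_lora injects lora_list/lora_rank/lora_alpha/lora_dropout into encoder_conf and selects the lora_transformer / lora_conformer encoder, and --only_optimize_lora calls mark_only_lora_as_trainable(model, bias='lora_only') in init_model.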