From a3a046e73883793713cbce7a086231339aa96ba2 Mon Sep 17 00:00:00 2001 From: Shuai Wang Date: Thu, 25 Apr 2024 23:30:47 +0800 Subject: [PATCH] [feature] add support for gemini-dfresnet (#291) * [feature] add support for gemini-dfresnet * fix lint errors * add warmup of 6 epochs to config * add warmup of 6 epochs to config * add the results of gemini-df-resnet * update the link for gemini models --- docs/pretrained.md | 22 ++- examples/cnceleb/v3_finetune/README.md | 9 +- examples/voxceleb/v2/README.md | 4 + .../v2/conf/gemini_dfresnet_adam.yaml | 81 ++++++++ .../v2/conf/gemini_dfresnet_sgd_lm.yaml | 91 +++++++++ wespeaker/models/gemini_dfresnet.py | 174 ++++++++++++++++++ wespeaker/models/speaker_model.py | 3 + 7 files changed, 371 insertions(+), 13 deletions(-) create mode 100644 examples/voxceleb/v2/conf/gemini_dfresnet_adam.yaml create mode 100644 examples/voxceleb/v2/conf/gemini_dfresnet_sgd_lm.yaml create mode 100644 wespeaker/models/gemini_dfresnet.py diff --git a/docs/pretrained.md b/docs/pretrained.md index 96fbf513..b4894a1b 100644 --- a/docs/pretrained.md +++ b/docs/pretrained.md @@ -39,14 +39,16 @@ in [the voxconverse recipe](https://github.com/wenet-e2e/wespeaker/tree/master/e ## Model List -| Datasets | Languages | Checkpoint (pt) | Runtime Model (onnx) | -|-----------------------------------------------|-----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [ResNet34](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34.zip) / 
[ResNet34_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.zip) | [ResNet34](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34.onnx) / [ResNet34_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx) | -| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [ResNet152_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet152_LM.zip) | [ResNet152_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet152_LM.onnx) | -| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [ResNet221_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet221_LM.zip) | [ResNet221_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet221_LM.onnx) | -| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [ResNet293_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet293_LM.zip) | [ResNet293_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet293_LM.onnx) | -| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [CAM++](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++.zip) / [CAM++_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++_LM.zip) | [CAM++](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++.onnx) / [CAM++_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++_LM.onnx) | -| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [ECAPA512](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512.zip) / [ECAPA512_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512_LM.zip) | 
[ECAPA512](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512.onnx) / [ECAPA512_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512_LM.onnx) | -| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [ECAPA1024](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA1024.zip) / [ECAPA1024_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA1024_LM.zip) | [ECAPA1024](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA1024.onnx) / [ECAPA1024_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA1024_LM.onnx) | -| [CNCeleb](../examples/cnceleb/v2/README.md) | CN | [ResNet34](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34.zip) / [ResNet34_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34_LM.zip) | [ResNet34](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34.onnx) / [ResNet34_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34_LM.onnx) | +| Datasets | Languages | Checkpoint (pt) | Runtime Model (onnx) | +|--- |--- |--- |--- | +| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [ResNet34](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34.zip) / [ResNet34_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.zip) | [ResNet34](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34.onnx) / [ResNet34_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx) | +| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | 
[ResNet152_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet152_LM.zip)| [ResNet152_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet152_LM.onnx) | +| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [ResNet221_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet221_LM.zip)| [ResNet221_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet221_LM.onnx) | +| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [ResNet293_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet293_LM.zip)| [ResNet293_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet293_LM.onnx) | +| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [CAM++](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++.zip) / [CAM++_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++_LM.zip) | [CAM++](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++.onnx) / [CAM++_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++_LM.onnx) | +| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [ECAPA512](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512.zip) / [ECAPA512_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512_LM.zip) | [ECAPA512](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512.onnx) / [ECAPA512_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA512_LM.onnx) | +| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [ECAPA1024](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA1024.zip) / 
[ECAPA1024_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA1024_LM.zip) | [ECAPA1024](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA1024.onnx) / [ECAPA1024_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_ECAPA1024_LM.onnx) | +| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [Gemini_DFResnet114_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_gemini_dfresnet114_LM.zip)| [Gemini_DFResnet114_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_gemini_dfresnet114_LM.onnx) | +| [CNCeleb](../examples/cnceleb/v2/README.md) | CN | [ResNet34](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34.zip) / [ResNet34_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34_LM.zip) | [ResNet34](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34.onnx) / [ResNet34_LM](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34_LM.onnx) | + diff --git a/examples/cnceleb/v3_finetune/README.md b/examples/cnceleb/v3_finetune/README.md index 0bc7d394..715c4fee 100644 --- a/examples/cnceleb/v3_finetune/README.md +++ b/examples/cnceleb/v3_finetune/README.md @@ -1,11 +1,9 @@ ## Fine-tuning Results Based on DINO -* Setup: fbank80, num_frms200, epoch75 (pretrain), epoch50 (finetune), ArcMargin, aug_prob0.6, speed_perturb (no spec_aug) -* [Pre-trained ECAPA-TDNN checkpoints](https://drive.google.com/drive/folders/1XDIUjnKPrvJE5auBWT5CcE4mqcglCwzq?usp=drive_link): teacher models extracted from `model_75.pt` (please refer to `wespeaker/ssl/bin/average_dino_model.py` for information on the extraction process) +* Setup: fbank80, num_frms200, epoch50 (finetune), ArcMargin, aug_prob0.6, speed_perturb (no spec_aug) * test_trials: CNC-Eval-Avg.lst * These 
results are obtained by pretraining on different datasets and then finetuning with CNCeleb. - | Model | Params | FLOPs | Pretraining Data | LM | AS-Norm | EER (%) | minDCF (p=0.01) | | :------------------------------ | :-----: | :-----: | :--------------------: | :-: | :-------: | :-------: | :--------------: | | ECAPA_TDNN_GLOB_c1024-ASTP-emb192 | 14.65M | 2.65 G | CNCeleb | × | × | 8.217 | 0.439 | @@ -20,3 +18,8 @@ * 🔥 UPDATE 2024.03: We support finetuning DINO-based self-supervised models, which is trained on the WenetSpeech dataset. Pretrained Paper related to the finetuning results: * [WenetSpeech: A 10000+ Hours Multi-domain Mandarin Corpus for Speech Recognition](https://arxiv.org/pdf/2110.03370.pdf) * [Leveraging In-the-wild Data for Effective Self-supervised Pretraining in Speaker Recognition](https://arxiv.org/pdf/2309.11730.pdf) + +## Resources +* [Pre-trained ECAPA-TDNN checkpoints](https://drive.google.com/drive/folders/1XDIUjnKPrvJE5auBWT5CcE4mqcglCwzq?usp=drive_link) +* [The filtering metadata for wenetspeech](https://drive.google.com/file/d/1UaGuyT1wcKc5g9vRdfIBvLoDRcuOxBlX/view?usp=drive_link) + diff --git a/examples/voxceleb/v2/README.md b/examples/voxceleb/v2/README.md index 35739090..7f8876f1 100644 --- a/examples/voxceleb/v2/README.md +++ b/examples/voxceleb/v2/README.md @@ -47,6 +47,10 @@ | | | | √ | √ | 0.744 | 0.896 | 1.603 | | Res2Net34_Base | 4.68M | 1.77G | × | × | 1.351 | 1.347 | 2.478 | | | | | × | √ | 1.234 | 1.232 | 2.162 | +| Gemini_DFResNet114 | 6.53M | 5.42G | × | × | 0.787 | 0.963 | 1.760 | +| | | | × | √ | 0.707 | 0.889 | 1.546 | +| | | | √ | × | 0.771 | 0.906 | 1.599 | +| | | | √ | √ | 0.638 | 0.839 | 1.427 | ## PLDA results diff --git a/examples/voxceleb/v2/conf/gemini_dfresnet_adam.yaml b/examples/voxceleb/v2/conf/gemini_dfresnet_adam.yaml new file mode 100644 index 00000000..f493fdde --- /dev/null +++ b/examples/voxceleb/v2/conf/gemini_dfresnet_adam.yaml @@ -0,0 +1,81 @@ +### train configuration + +exp_dir: 
exp/Gemini_DF_ResNet114-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-AdamW-epoch165 +gpus: "[0,1]" +num_avg: 2 +enable_amp: False # whether enable automatic mixed precision training + +seed: 42 +num_epochs: 165 +save_epoch_interval: 5 # save model every 5 epochs +log_batch_interval: 100 # log every 100 batches + +dataloader_args: + batch_size: 128 + num_workers: 8 + pin_memory: False + prefetch_factor: 8 + drop_last: True + +dataset_args: + # the sample number which will be traversed within one epoch, if the value equals to 0, + # the utterance number in the dataset will be used as the sample_num_per_epoch. + sample_num_per_epoch: 0 + shuffle: True + shuffle_args: + shuffle_size: 2500 + filter: True + filter_args: + min_num_frames: 100 + max_num_frames: 800 + resample_rate: 16000 + speed_perturb: True + num_frms: 200 + aug_prob: 0.6 # prob to add reverb & noise aug per sample + fbank_args: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: False + spec_aug_args: + num_t_mask: 1 + num_f_mask: 1 + max_t: 10 + max_f: 8 + prob: 0.6 + +model: Gemini_DF_ResNet114 # Gemini_DF_ResNet60 Gemini_DF_ResNet114 Gemini_DF_ResNet183 Gemini_DF_ResNet237 +model_init: null +model_args: + feat_dim: 80 + embed_dim: 256 + pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP + two_emb_layer: False +projection_args: + project_type: "arc_margin" # add_margin, arc_margin, sphere, sphereface2, softmax, arc_margin_intertopk_subcenter + scale: 32.0 + easy_margin: False + +margin_scheduler: MarginScheduler +margin_update: + initial_margin: 0.2 + final_margin: 0.2 + increase_start_epoch: 20 + fix_start_epoch: 40 + update_margin: False + increase_type: "exp" # exp, linear + +loss: CrossEntropyLoss +loss_args: {} + +optimizer: AdamW +optimizer_args: + weight_decay: 0.05 + +scheduler: ExponentialDecrease +scheduler_args: + initial_lr: 0.000125 + final_lr: 0.000001 + warm_up_epoch: 6 + warm_from_zero: False diff --git 
a/examples/voxceleb/v2/conf/gemini_dfresnet_sgd_lm.yaml b/examples/voxceleb/v2/conf/gemini_dfresnet_sgd_lm.yaml new file mode 100644 index 00000000..3dd33bf7 --- /dev/null +++ b/examples/voxceleb/v2/conf/gemini_dfresnet_sgd_lm.yaml @@ -0,0 +1,91 @@ +### Large margin fine-tuning configuration +# +# The large margin fine-tuning operation is often used in speaker +# verification challenge systems to further improve the performance. +# In this fine-tuning stage, large margin and longer segment will +# be used. + +exp_dir: exp/Gemini_DF_ResNet114-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-AdamW-epoch165-LM +gpus: "[0,1]" +num_avg: 1 +enable_amp: False # whether enable automatic mixed precision training +do_lm: True + +seed: 42 +num_epochs: 5 +save_epoch_interval: 1 # save model per epoch +log_batch_interval: 100 # log every 100 batches + +dataloader_args: + batch_size: 32 + num_workers: 8 + pin_memory: False + prefetch_factor: 8 + drop_last: True + +dataset_args: + # the sample number which will be traversed within one epoch, if the value equals to 0, + # the utterance number in the dataset will be used as the sample_num_per_epoch. 
+ sample_num_per_epoch: 0 + shuffle: True + shuffle_args: + shuffle_size: 2500 + filter: True + filter_args: + min_num_frames: 100 + max_num_frames: 800 + resample_rate: 16000 + speed_perturb: True + num_frms: 600 + aug_prob: 0.6 # prob to add reverb & noise aug per sample + fbank_args: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: False + spec_aug_args: + num_t_mask: 1 + num_f_mask: 1 + max_t: 10 + max_f: 8 + prob: 0.6 + +model: Gemini_DF_ResNet114 # Gemini_DF_ResNet60, Gemini_DF_ResNet114, Gemini_DF_ResNet183, Gemini_DF_ResNet237 +model_init: null +model_args: + feat_dim: 80 + embed_dim: 256 + pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP + two_emb_layer: False +projection_args: + project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter + scale: 32.0 + easy_margin: False + +margin_scheduler: MarginScheduler +margin_update: + initial_margin: 0.5 + final_margin: 0.5 + increase_start_epoch: 1 + fix_start_epoch: 1 + update_margin: True + increase_type: "exp" # exp, linear + +loss: CrossEntropyLoss +loss_args: {} + +optimizer: SGD +optimizer_args: + momentum: 0.9 + nesterov: True + weight_decay: 0.0001 + +scheduler: ExponentialDecrease +scheduler_args: + initial_lr: 1.0e-4 + final_lr: 2.5e-5 + warm_up_epoch: 1 + warm_from_zero: True + + diff --git a/wespeaker/models/gemini_dfresnet.py b/wespeaker/models/gemini_dfresnet.py new file mode 100644 index 00000000..61918369 --- /dev/null +++ b/wespeaker/models/gemini_dfresnet.py @@ -0,0 +1,174 @@ +# Copyright (c) 2024 Shuai Wang (wsstriving@gmail.com) +# 2024 Tianchi Liu (tianchi_liu@u.nus.edu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +'''The implementation of Gemini-DF-ResNet. + +Reference: +[1] Liu, Tianchi, et al. "Golden Gemini is All You Need: Finding the + Sweet Spots for Speaker Verification." arXiv:2312.03620 (2023). +[2] Liu, Bei, et al. "DF-ResNet: Boosting Speaker Verification Performance + with Depth-First Design." INTERSPEECH. 2022. +''' +import torch +import torch.nn as nn +import torch.nn.functional as F +import wespeaker.models.pooling_layers as pooling_layers + + +class Inverted_Bottleneck(nn.Module): + def __init__(self, dim): + super(Inverted_Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(dim, 4 * dim, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(4 * dim) + self.conv2 = nn.Conv2d(4 * dim, 4 * dim, + kernel_size=3, padding=1, groups=4 * dim, + bias=False) + self.bn2 = nn.BatchNorm2d(4 * dim) + self.conv3 = nn.Conv2d(4 * dim, dim, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(dim) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += x + out = F.relu(out) + return out + + +class Gemini_DF_ResNet(nn.Module): + # DF_ResNet with T14c stride strategy of Golden Gemini + def __init__(self, + depths, + dims, + feat_dim=40, + embed_dim=128, + pooling_func='TSTP', + two_emb_layer=False): + super(Gemini_DF_ResNet, self).__init__() + self.feat_dim = feat_dim + self.embed_dim = embed_dim + self.stats_dim = int(feat_dim / 8 / 2) * dims[-1] + self.two_emb_layer = two_emb_layer + + self.downsample_layers = nn.ModuleList() + stem = nn.Sequential( + nn.Conv2d(1, 
dims[0], kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(dims[0]), + nn.ReLU() + ) + self.downsample_layers.append(stem) + + stride_f = [2, 2, 2, 2] + stride_t = [1, 2, 1, 1] + + for i in range(4): + downsample_layer = nn.Sequential( + nn.Conv2d( + dims[i], dims[i + 1], kernel_size=3, + stride=(stride_f[i], stride_t[i]), + padding=1, bias=False), + nn.BatchNorm2d(dims[i + 1]) + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList() + for i in range(4): + stage = nn.Sequential( + *[Inverted_Bottleneck(dim=dims[i + 1]) for _ in range(depths[i])] + ) + self.stages.append(stage) + + self.pool = getattr(pooling_layers, + pooling_func)(in_dim=self.stats_dim) + self.pool_out_dim = self.pool.get_out_dim() + self.seg_1 = nn.Linear(self.pool_out_dim, embed_dim) + if self.two_emb_layer: + self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False) + self.seg_2 = nn.Linear(embed_dim, embed_dim) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = x.unsqueeze_(1) + out = self.downsample_layers[0](x) + out = self.downsample_layers[1](out) + out = self.stages[0](out) + out = self.downsample_layers[2](out) + out = self.stages[1](out) + out = self.downsample_layers[3](out) + out = self.stages[2](out) + out = self.downsample_layers[4](out) + out = self.stages[3](out) + + stats = self.pool(out) + + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_a, embed_b + else: + return torch.tensor(0.0), embed_a + + +# following models do include separate downsampling layers into layer counting +def Gemini_DF_ResNet60(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=False): + return Gemini_DF_ResNet(depths=[3, 3, 9, 3], + dims=[32, 32, 64, 128, 256], + feat_dim=feat_dim, + embed_dim=embed_dim, + pooling_func=pooling_func, + two_emb_layer=two_emb_layer) + + 
+def Gemini_DF_ResNet114(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=False): + return Gemini_DF_ResNet(depths=[3, 3, 27, 3], + dims=[32, 32, 64, 128, 256], + feat_dim=feat_dim, + embed_dim=embed_dim, + pooling_func=pooling_func, + two_emb_layer=two_emb_layer) + + +def Gemini_DF_ResNet183(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=False): + return Gemini_DF_ResNet(depths=[3, 8, 45, 3], + dims=[32, 32, 64, 128, 256], + feat_dim=feat_dim, + embed_dim=embed_dim, + pooling_func=pooling_func, + two_emb_layer=two_emb_layer) + + +def Gemini_DF_ResNet237(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=False): + return Gemini_DF_ResNet(depths=[3, 8, 63, 3], + dims=[32, 32, 64, 128, 256], + feat_dim=feat_dim, + embed_dim=embed_dim, + pooling_func=pooling_func, + two_emb_layer=two_emb_layer) + + +if __name__ == '__main__': + x = torch.zeros(1, 200, 80) + model = Gemini_DF_ResNet114(80, 256, 'TSTP') + model.eval() + out = model(x) + print(out[-1].size()) + + num_params = sum(p.numel() for p in model.parameters()) + print("{} M".format(num_params / 1e6)) diff --git a/wespeaker/models/speaker_model.py b/wespeaker/models/speaker_model.py index 70a6bc1d..8475f1ae 100644 --- a/wespeaker/models/speaker_model.py +++ b/wespeaker/models/speaker_model.py @@ -18,6 +18,7 @@ import wespeaker.models.repvgg as repvgg import wespeaker.models.campplus as campplus import wespeaker.models.eres2net as eres2net +import wespeaker.models.gemini_dfresnet as gemini import wespeaker.models.res2net as res2net @@ -36,6 +37,8 @@ def get_speaker_model(model_name: str): return getattr(eres2net, model_name) elif model_name.startswith("Res2Net"): return getattr(res2net, model_name) + elif model_name.startswith("Gemini"): + return getattr(gemini, model_name) else: # model_name error !!! print(model_name + " not found !!!") exit(1)