[model] MQMHA+arc_margin_intertopk_subcenter #115

Merged · 1 commit · Nov 30, 2022
4 changes: 2 additions & 2 deletions examples/voxceleb/v2/conf/resnet.yaml
@@ -43,10 +43,10 @@ model_init: null
model_args:
  feat_dim: 80
  embed_dim: 256
-  pooling_func: "TSTP"
+  pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
  two_emb_layer: False
projection_args:
-  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
  scale: 32.0
  easy_margin: False

4 changes: 2 additions & 2 deletions examples/voxceleb/v2/conf/resnet_lm.yaml
@@ -49,10 +49,10 @@ model_init: null
model_args:
  feat_dim: 80
  embed_dim: 256
-  pooling_func: "TSTP"
+  pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
  two_emb_layer: False
projection_args:
-  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
  scale: 32.0
  easy_margin: False

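Both config edits only extend the comments: the strings listed are the class names exposed by wespeaker/models/pooling_layers.py and by the projection module, and the model resolves them by attribute lookup (see the ecapa_tdnn.py hunk below). A minimal sketch of that resolution, assuming wespeaker is importable and using an illustrative feature width:

import torch
import wespeaker.models.pooling_layers as pooling_layers

# the config string is resolved to a class by attribute lookup
pool = getattr(pooling_layers, "MQMHASTP")(in_dim=1536)
out = pool(torch.randn(4, 1536, 100))  # (batch, feat, time)
print(out.shape, pool.get_out_dim())   # torch.Size([4, 6144]) 6144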
1 change: 1 addition & 0 deletions wespeaker/bin/train.py
@@ -118,6 +118,7 @@ def train(config='conf/config.yaml', **kwargs):
    if configs.get('do_lm', False):
        logger.info('No speed perturb while doing large margin fine-tuning')
        configs['dataset_args']['speed_perturb'] = False
+    configs['projection_args']['do_lm'] = configs.get('do_lm', False)
    projection = get_projection(configs['projection_args'])
    model.add_module("projection", projection)
    if rank == 0:
1 change: 1 addition & 0 deletions wespeaker/bin/train_deprecated.py
@@ -123,6 +123,7 @@ def train(config='conf/config.yaml', **kwargs):
    if configs['feature_args']['raw_wav'] and configs['dataset_args']['speed_perturb']:
        # diff speed is regarded as diff spk
        configs['projection_args']['num_class'] *= 3
+    configs['projection_args']['do_lm'] = configs.get('do_lm', False)
    projection = get_projection(configs['projection_args'])
    model.add_module("projection", projection)
    if rank == 0:
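Both entry points gain the same line: the do_lm flag is copied into projection_args before get_projection builds the head, so a margin-based projection can tell whether it is running in the large-margin fine-tuning stage. The real dispatch lives in wespeaker/models/projections.py; the sketch below only illustrates the pattern, and the class and helper names in it are hypothetical:

import torch
import torch.nn as nn

class ToyArcMargin(nn.Module):
    """Hypothetical stand-in for a margin-based projection head."""
    def __init__(self, embed_dim, num_class, do_lm=False):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_class, embed_dim))
        self.do_lm = do_lm  # e.g. switch to a larger fixed margin when True

def get_projection_sketch(args):
    # dispatch on the project_type string, mirroring the config keys above
    if args['project_type'] == 'arc_margin':
        return ToyArcMargin(args['embed_dim'], args['num_class'],
                            do_lm=args.get('do_lm', False))
    raise ValueError('unknown project_type: %s' % args['project_type'])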
8 changes: 4 additions & 4 deletions wespeaker/models/ecapa_tdnn.py
@@ -191,11 +191,11 @@ def __init__(self,
        cat_channels = channels * 3
        out_channels = 512 * 3
        self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1)
-        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=out_channels, global_context_att=global_context_att)
-        self.bn = nn.BatchNorm1d(out_channels * self.n_stats)
-        self.linear = nn.Linear(out_channels * self.n_stats, embed_dim)
+        self.pool_out_dim = self.pool.get_out_dim()
+        self.bn = nn.BatchNorm1d(self.pool_out_dim)
+        self.linear = nn.Linear(self.pool_out_dim, embed_dim)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) -> (B,F,T)
@@ -247,7 +247,7 @@ def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func='ASTP'):
    x = torch.zeros(10, 200, 80)
    model = ECAPA_TDNN_GLOB_c512(feat_dim=80,
                                 embed_dim=192,
-                                pooling_func='ASTP')
+                                pooling_func='MQMHASTP')
    model.eval()
    out = model(x)
    print(out.shape)
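The refactor above is what lets MQMHASTP drop in: the pooled width is no longer in_dim times a hard-coded n_stats factor, since MQMHASTP emits in_dim * 2 * query_num. Each pooling class now reports its own output size, and every class accepts **kwargs so that ecapa_tdnn.py can pass global_context_att to whichever one is configured. A quick check of that contract, assuming wespeaker is importable and with illustrative dimensions:

import torch
import wespeaker.models.pooling_layers as pooling_layers

in_dim = 1536  # ECAPA cat-conv width, 512 * 3
for name in ["TAP", "TSDP", "TSTP", "ASTP", "MQMHASTP"]:
    pool = getattr(pooling_layers, name)(in_dim=in_dim)
    out = pool(torch.randn(2, in_dim, 50))
    # the linear/BN layers can now size themselves from get_out_dim()
    assert out.shape[1] == pool.get_out_dim(), name
    print(name, pool.get_out_dim())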
185 changes: 179 additions & 6 deletions wespeaker/models/pooling_layers.py
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Pooling functions to aggregate frame-level deep features
into segment-level speaker embeddings
Expand All @@ -22,62 +21,82 @@

import torch
import torch.nn as nn
import torch.nn.functional as F


class TAP(nn.Module):
    """
    Temporal average pooling, only first-order mean is considered
    """
-    def __init__(self, **kwargs):
+    def __init__(self, in_dim=0, **kwargs):
        super(TAP, self).__init__()
+        self.in_dim = in_dim

    def forward(self, x):
        pooling_mean = x.mean(dim=-1)
        # To be compatible with 2D input
        pooling_mean = pooling_mean.flatten(start_dim=1)
        return pooling_mean

+    def get_out_dim(self):
+        self.out_dim = self.in_dim
+        return self.out_dim


class TSDP(nn.Module):
    """
    Temporal standard deviation pooling, only second-order std is considered
    """
-    def __init__(self, **kwargs):
+    def __init__(self, in_dim=0, **kwargs):
        super(TSDP, self).__init__()
+        self.in_dim = in_dim

    def forward(self, x):
        # The last dimension is the temporal axis
        pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
        pooling_std = pooling_std.flatten(start_dim=1)
        return pooling_std

+    def get_out_dim(self):
+        self.out_dim = self.in_dim
+        return self.out_dim


class TSTP(nn.Module):
    """
    Temporal statistics pooling, concatenating mean and std, which is used in
    x-vector
    Comment: simple concatenation cannot make full use of both statistics
    """
-    def __init__(self, **kwargs):
+    def __init__(self, in_dim=0, **kwargs):
        super(TSTP, self).__init__()
+        self.in_dim = in_dim

    def forward(self, x):
        # The last dimension is the temporal axis
        pooling_mean = x.mean(dim=-1)
        pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
        pooling_mean = pooling_mean.flatten(start_dim=1)
        pooling_std = pooling_std.flatten(start_dim=1)

        stats = torch.cat((pooling_mean, pooling_std), 1)
        return stats

+    def get_out_dim(self):
+        self.out_dim = self.in_dim * 2
+        return self.out_dim


class ASTP(nn.Module):
    """ Attentive statistics pooling: channel- and context-dependent
    statistics pooling, first used in ECAPA-TDNN.
    """
-    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
+    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False, **kwargs):
        super(ASTP, self).__init__()
+        self.in_dim = in_dim
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't
@@ -119,3 +138,157 @@ def forward(self, x):
        var = torch.sum(alpha * (x**2), dim=2) - mean**2
        std = torch.sqrt(var.clamp(min=1e-10))
        return torch.cat([mean, std], dim=1)

    def get_out_dim(self):
        self.out_dim = 2 * self.in_dim
        return self.out_dim


class MHASTP(torch.nn.Module):
    """ Multi-head attentive statistics pooling
    Reference:
        Self Multi-Head Attention for Speaker Recognition
        https://arxiv.org/pdf/1906.09890.pdf
    """

    def __init__(self,
                 in_dim,
                 layer_num=2,
                 head_num=2,
                 d_s=1,
                 bottleneck_dim=64,
                 **kwargs):
        super(MHASTP, self).__init__()
        # in_dim must be divisible by head_num
        assert (in_dim % head_num) == 0
        self.in_dim = in_dim
        self.head_num = head_num
        d_model = int(in_dim / head_num)
        channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
        if d_s > 1:
            d_s = d_model
        else:
            d_s = 1
        self.d_s = d_s
        channel_dims[0], channel_dims[-1] = d_model, d_s
        heads_att_trans = []
        for i in range(self.head_num):
            att_trans = nn.Sequential()
            for j in range(layer_num - 1):
                att_trans.add_module(
                    'att_' + str(j),
                    nn.Conv1d(channel_dims[j], channel_dims[j + 1], 1, 1))
                att_trans.add_module('tanh' + str(j), nn.Tanh())
            att_trans.add_module(
                'att_' + str(layer_num - 1),
                nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num],
                          1, 1))
            heads_att_trans.append(att_trans)
        self.heads_att_trans = nn.ModuleList(heads_att_trans)

    def forward(self, input):
        """
        input: a 3-dimensional tensor in xvector architecture
            or a 4-dimensional tensor in resnet architecture
            0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
        """
        if len(input.shape) == 4:  # (B, C, F, T) -> (B, C*F, T)
            input = input.reshape(input.shape[0],
                                  input.shape[1] * input.shape[2],
                                  input.shape[3])
        assert len(input.shape) == 3
        bs, f_dim, t_dim = input.shape
        # split the feature dimension into head_num chunks
        chunks = torch.chunk(input, self.head_num, 1)
        chunks_out = []
        for i, layer in enumerate(self.heads_att_trans):
            att_score = layer(chunks[i])
            alpha = F.softmax(att_score, dim=-1)
            mean = torch.sum(alpha * chunks[i], dim=2)
            var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2
            std = torch.sqrt(var.clamp(min=1e-10))
            chunks_out.append(torch.cat((mean, std), dim=1))
        out = torch.cat(chunks_out, dim=1)
        return out

    def get_out_dim(self):
        self.out_dim = 2 * self.in_dim
        return self.out_dim


class MQMHASTP(torch.nn.Module):
    """ Multi-query multi-head attentive statistics pooling
    Reference:
        multi query multi head attentive statistics pooling
        https://arxiv.org/pdf/2110.05042.pdf
    Args:
        in_dim: the feature dimension of input
        layer_num: the number of layers in the pooling layer
        query_num: the number of queries
        head_num: the number of heads
        bottleneck_dim: the bottleneck dimension
    Special cases of the hyper-parameters recover earlier poolings:
    SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
        https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
    MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
        https://arxiv.org/pdf/1906.09890.pdf
    AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
        https://arxiv.org/pdf/1803.10963.pdf
    VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
        http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
    """

    def __init__(self,
                 in_dim,
                 layer_num=2,
                 query_num=2,
                 head_num=8,
                 d_s=2,
                 bottleneck_dim=64,
                 **kwargs):
        super(MQMHASTP, self).__init__()
        self.n_query = nn.ModuleList([
            MHASTP(in_dim,
                   layer_num=layer_num,
                   head_num=head_num,
                   d_s=d_s,
                   bottleneck_dim=bottleneck_dim) for i in range(query_num)
        ])
        self.query_num = query_num
        self.in_dim = in_dim

    def forward(self, input):
        """
        input: a 3-dimensional tensor in xvector architecture
            or a 4-dimensional tensor in resnet architecture
            0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
        """
        if len(input.shape) == 4:  # (B, C, F, T) -> (B, C*F, T)
            input = input.reshape(input.shape[0],
                                  input.shape[1] * input.shape[2],
                                  input.shape[3])
        assert len(input.shape) == 3
        res = []
        for i, layer in enumerate(self.n_query):
            res.append(layer(input))
        out = torch.cat(res, dim=-1)
        return out

    def get_out_dim(self):
        self.out_dim = self.in_dim * 2 * self.query_num
        return self.out_dim


if __name__ == '__main__':
    data = torch.randn(16, 512, 10, 35)
    # quick smoke test; only the last constructed model is exercised below
    model = MQMHASTP(512 * 10)
    model = MHASTP(512 * 10)
    model = MQMHASTP(512 * 10, context=False)
    print(model)

    out = model(data)
    print(out.shape)
    print(model.get_out_dim())
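To pin down the arithmetic the smoke test exercises, a short check (assuming the classes above are in scope, e.g. when run inside pooling_layers.py):

import torch

pool = MQMHASTP(512 * 10)                 # default head_num=8: 5120 dims -> 640 per head
out = pool(torch.randn(16, 512, 10, 35))  # 4-D input is flattened to (16, 5120, 35)
# each query's MHASTP emits 2 * 5120 values; default query_num=2 concatenates two
assert out.shape == (16, 512 * 10 * 2 * 2)  # torch.Size([16, 20480])
print(out.shape)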