From b5209eb702f7080f6269ab20816531d805e2ed52 Mon Sep 17 00:00:00 2001 From: lukhy Date: Thu, 20 Jun 2019 21:03:49 +0800 Subject: [PATCH] update for training --- models/conv.py | 6 ++++-- train.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/models/conv.py b/models/conv.py index ed90888..64f7147 100644 --- a/models/conv.py +++ b/models/conv.py @@ -10,7 +10,7 @@ def __init__(self, conv, p): super().__init__() self.conv = conv nn.init.kaiming_normal_(self.conv.weight) - weight_norm(self.conv) + self.conv = weight_norm(self.conv) self.act = nn.GLU(1) self.dropout = nn.Dropout(p, inplace=True) @@ -50,7 +50,9 @@ def forward(self, x, lens): # -> B * V * T x = self.cnn(x) for module in self.modules(): if type(module) == nn.modules.Conv1d: - lens = (lens - module.kernel_size[0]) / module.stride[0] + 1 + lens = ( + lens - module.kernel_size[0] + 2 * module.padding[0] + ) // module.stride[0] + 1 return x, lens def predict(self, path): diff --git a/train.py b/train.py index 13f576a..3d5f474 100644 --- a/train.py +++ b/train.py @@ -75,7 +75,7 @@ def train( writer.add_scalar("loss/epoch", epoch_loss, epoch) writer.add_scalar("cer/epoch", cer, epoch) print("Epoch {}: Loss= {}, CER = {}".format(epoch, epoch_loss, cer)) - torch.save(model, "models/v5/casr_v5_1_epoch_{}.pth".format(epoch)) + torch.save(model, "pretrained/model_{}.pth".format(epoch)) def get_lr(optimizer):