Skip to content

Commit

Permalink
update for training
Browse files Browse the repository at this point in the history
  • Loading branch information
lukhy committed Jun 20, 2019
1 parent 82371b5 commit b5209eb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions models/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(self, conv, p):
super().__init__()
self.conv = conv
nn.init.kaiming_normal_(self.conv.weight)
weight_norm(self.conv)
self.conv = weight_norm(self.conv)
self.act = nn.GLU(1)
self.dropout = nn.Dropout(p, inplace=True)

Expand Down Expand Up @@ -50,7 +50,9 @@ def forward(self, x, lens): # -> B * V * T
x = self.cnn(x)
for module in self.modules():
if type(module) == nn.modules.Conv1d:
lens = (lens - module.kernel_size[0]) / module.stride[0] + 1
lens = (
lens - module.kernel_size[0] + 2 * module.padding[0]
) // module.stride[0] + 1
return x, lens

def predict(self, path):
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def train(
writer.add_scalar("loss/epoch", epoch_loss, epoch)
writer.add_scalar("cer/epoch", cer, epoch)
print("Epoch {}: Loss= {}, CER = {}".format(epoch, epoch_loss, cer))
torch.save(model, "models/v5/casr_v5_1_epoch_{}.pth".format(epoch))
torch.save(model, "pretrained/model_{}.pth".format(epoch))


def get_lr(optimizer):
Expand Down

0 comments on commit b5209eb

Please sign in to comment.