INF loss caused by ScaledLinear #534

Closed
drawfish opened this issue Aug 17, 2022 · 8 comments

@drawfish

After 30 epochs of training, an inf loss appeared in a certain batch on my own dataset.
My training script is modified from egs/librispeech/pruned_transducer_stateless2.
The problem seems to be in the weight parameters of the simple_lm_proj layer, which is a ScaledLinear module.
The cause of this exception is that the real (scaled) weight of the layer has exceeded the representable range of float32 (converting to 64-bit floating point avoids it).
Maybe the construction of the ScaledLinear module needs to be reconsidered.

Case for reproduction:

model['model']['simple_lm_proj.weight'][750:751,] = 
tensor([[-5.1551e-02, -1.2240e+00, -7.3813e-02, -3.3730e-02, -1.1170e-01,
         -1.6473e-02, -9.9509e-01, -1.1051e-01, -4.0209e-02, -1.1383e-01,
         -8.0831e-02, -4.1723e-02, -8.4578e-02, -6.2695e-02, -3.7586e-02,
         -5.3306e-02, -2.7846e-01, -4.9898e-02, -4.4774e-02, -6.7214e-02,
         -1.6483e-01, -2.1416e-01, -6.1130e-02, -2.9127e-01, -3.5063e-02,
         -1.0238e-01, -1.7096e-02, -1.1050e-01, -4.8205e-02, -1.6165e-01,
         -5.6212e-02, -3.8316e-01, -6.5313e-02, -8.4273e-02, -3.5621e-02,
          6.7110e-03, -7.1563e-02, -1.2871e-01, -1.0845e-02, -6.4662e-02,
         -8.9876e-02, -9.2465e-01, -2.9431e-02, -3.5038e-02, -1.7246e-01,
         -6.0021e-02, -5.8539e-02, -3.6222e-02, -1.9723e-01, -2.1020e-02,
         -3.7115e-01, -7.6494e-02, -4.6610e-02, -7.2946e-02, -1.7218e-01,
         -8.1635e-02, -6.5075e-02, -5.5851e-02, -5.7309e-02, -5.2118e-02,
         -2.9748e-01, -8.4485e-02, -2.5913e-01, -3.1858e-02, -5.5928e-02,
         -8.7669e-03, -6.1204e-02, -2.6659e-01, -3.9655e-02, -7.5340e-01,
         -9.5071e-01, -1.2468e-01, -2.5305e-01, -3.7410e-02, -4.6897e-02,
         -5.0210e-02, -1.7230e-01, -7.9918e-02, -7.9871e-02, -4.5755e-02,
         -5.3539e-02, -5.2717e-02, -7.1866e-02, -6.1196e-01, -1.0631e-01,
         -3.0501e-01, -6.7029e-02, -4.0855e-02, -8.2058e-02, -8.0724e-02,
         -6.9097e-02, -4.5394e-02, -6.6132e-02, -5.6880e-02, -2.2283e-02,
         -1.5411e-01, -1.2072e-01, -5.0433e-02, -9.7660e-02, -9.3472e-02,
         -7.4087e-02, -7.5583e-02, -2.3613e-01, -1.5705e-02, -1.2086e-01,
         -3.4739e-02, -6.3825e-01, -5.2309e-02, -1.0070e-01, -1.0800e-01,
         -8.1152e-02, -6.3586e-01, -2.1534e-01, -5.0373e-02, -6.6397e-02,
         -1.4448e+00, -4.7198e-01, -5.6727e-02, -5.3544e-02, -4.8757e-02,
         -6.4949e-02, -3.0841e-02, -7.2466e-02,  2.1810e-03, -3.8469e-02,
         -4.7542e-01, -5.5890e-02, -5.7630e-02, -4.2648e-02, -1.9552e-03,
         -5.8567e-02, -3.3427e-02, -1.9647e-02, -4.9359e-02, -3.4640e-01,
         -1.3546e-01, -1.3460e-01, -3.3435e-02, -1.1019e-01, -4.2527e-02,
         -6.3337e-02, -7.0136e-02, -1.8391e-02, -7.6174e-02, -5.6761e-02,
         -7.1759e-02, -5.1425e-02, -1.0638e-01, -3.8656e-02, -4.6370e-02,
         -6.5202e-02, -6.8784e-02, -4.5518e-02, -4.7009e-02, -4.5324e-02,
         -5.9385e-02, -5.0962e-02,  3.7499e-03, -2.4719e-02, -7.3078e-03,
         -5.7817e-02, -8.4220e-02, -1.7582e-01, -5.1043e-02, -5.0012e-02,
         -4.8936e-02, -1.4196e-02, -8.0739e-02, -1.3479e-01, -1.1806e-01,
         -1.1344e-01, -7.7569e-02, -5.0056e-01, -1.5426e-01, -4.9186e-02,
         -7.7820e-02, -2.1644e-01, -9.5791e-02, -3.1976e-02, -3.9607e-02,
         -2.5053e-01, -3.8545e-02, -1.1121e-01, -4.4510e-02, -7.6218e-01,
         -5.5244e-02, -7.3911e-01, -9.3564e-02, -4.9999e-02, -4.2123e-02,
         -1.2258e-01, -3.6594e-02, -1.0302e-01, -4.5226e-02, -1.1090e-01,
         -4.8002e-02, -9.1842e-02, -5.2101e-02, -5.2603e-02, -2.1610e-01,
         -8.2486e-02, -8.6079e-02, -1.1752e-01, -6.5852e-02, -5.8898e-02,
         -6.4289e-02, -3.8480e-02, -1.2140e-01, -3.7018e-02, -4.2722e-01,
         -1.4040e-01, -4.0513e-02, -4.3174e-02, -5.6327e-02,  7.7248e-03,
         -2.5351e-02, -1.2188e-01, -5.9194e-02, -1.4487e-01, -9.8768e-02,
         -1.8639e-02, -7.2761e-02, -1.0941e-01, -1.0838e-01, -2.6792e-02,
         -3.7384e-01, -3.9063e-02, -4.7863e-02, -4.1156e-02, -1.3061e-01,
         -3.7968e-02, -5.1367e-02, -4.2187e-02, -2.5586e-01, -2.3548e-01,
         -1.1322e-01, -9.1195e-02, -6.1605e-02, -3.8107e-02, -6.6706e-01,
         -1.5474e-01, -6.3752e-02, -6.7706e-02, -3.5113e-02, -5.7614e-02,
         -7.8612e-02, -7.6630e-02, -5.3708e-04, -3.1338e-02, -5.2683e-02,
         -5.4282e-02, -4.0769e-02, -1.1259e-01, -1.2182e-01, -2.6559e-02,
         -4.6250e-02, -3.0973e-01, -4.5191e-02, -6.7004e-02, -3.7445e-02,
         -6.6337e-02, -2.0377e-01, -5.0093e-02, -2.0332e-01, -5.3287e-02,
         -9.4977e-02, -6.3003e-02, -1.0838e-01, -6.2606e-02, -2.8038e-01,
         -7.0678e-02, -2.4482e-01, -9.8526e-02, -5.2309e-02, -1.1569e-01,
         -9.7541e-02, -4.9262e-02, -5.9300e-02, -5.6827e-02, -4.9262e-01,
         -2.7272e-02, -3.2982e-02, -5.3130e-02, -4.5089e-02, -3.1125e-02,
         -5.7751e-01, -1.3850e-01, -3.8865e-01, -5.2711e-02, -4.8536e-02,
         -1.2963e-01, -1.7405e-01, -5.1960e-02, -8.6757e-02, -4.5462e-02,
         -4.7745e-02, -3.4743e-01, -1.7050e-01, -5.3401e-02, -3.4623e-01,
         -8.0635e-02, -4.2998e-02, -1.5925e-01, -4.3076e-02, -2.7254e-02,
         -5.6783e-02, -5.5663e-02, -7.1640e-02, -1.0117e-01, -7.0240e-02,
         -7.4079e-02, -2.8625e-01, -9.4638e-02, -4.6530e-02, -5.9577e-02,
         -3.2659e-02, -2.8791e-02, -5.0677e-02, -4.7817e-02, -7.1293e-01,
         -4.2731e-02, -1.6784e-02, -2.0738e-01, -5.0554e-02, -6.8905e-02,
         -1.8411e-01, -6.9809e-02, -6.8151e-02, -5.8461e-02, -2.5264e-01,
         -6.7830e-01, -3.7961e-02, -9.1771e-02, -6.7530e-02, -5.2568e-02,
         -9.6570e-02, -4.6387e-02, -1.9618e-02, -1.3361e-01, -1.4486e-01,
         -4.8101e-02, -1.1394e+00, -5.3716e-02, -7.2697e-01, -6.0544e-02,
         -1.3308e+00, -5.3715e-02, -5.2132e-02, -4.4533e-02, -6.4218e-01,
         -1.2380e-01, -2.1958e-02, -1.2413e-01, -4.1048e-02, -8.1298e-01,
         -5.9336e-02, -6.9936e-02, -9.5837e-03, -4.5090e-02, -2.4049e-02,
         -5.5434e-02, -6.7948e-02, -2.5848e-01, -2.6473e-01, -3.6305e-02,
         -7.9389e-02, -4.1287e-02, -4.1839e-02, -3.4426e-02,  3.4784e-02,
         -8.7626e-02, -2.9819e-01, -3.5634e-02,  7.5635e-03, -2.8337e-02,
         -2.8441e-02, -1.7008e-02, -1.4265e-01, -1.3546e-01, -6.8186e-02,
         -4.2162e-01, -1.4079e-01, -6.0626e-02, -1.0303e-01, -6.0701e-02,
         -2.8636e-01, -1.4830e-01, -4.7742e-02, -2.5351e-01, -3.0312e-02,
         -1.4263e-01, -8.8514e-02, -4.4926e-02, -3.8045e-02, -2.9936e-01,
         -9.4252e-01, -3.2635e-02, -8.2300e-02, -9.0060e-02, -4.4790e-02,
         -5.8432e-02, -8.9879e-02, -2.6613e-01, -5.6486e-02, -7.5909e-02,
         -9.8705e-02, -9.7507e-02, -6.7886e-02, -6.8242e-02, -6.2700e-02,
         -7.2419e-02, -4.0690e-02, -1.2547e-01, -4.2858e-02, -9.0597e-02,
         -3.1981e-02, -6.4046e-02, -9.9686e-02, -8.9982e-02, -6.2587e-02,
         -3.3651e-01, -2.2716e-01, -7.5307e-02, -8.7806e-02, -7.5388e-02,
         -6.0395e-02, -9.4728e-01, -2.2904e-01, -4.4637e-02, -3.4151e-02,
         -1.0278e-01, -5.7214e-02, -1.6894e-01, -5.3392e-02, -1.1832e-01,
         -4.8231e-02, -4.3478e-02, -5.9042e-02, -1.4215e-01, -3.6114e-02,
         -8.0854e-02, -1.0891e-01, -3.5372e-02, -4.9119e-02, -5.0923e-02,
         -7.6433e-02, -9.0131e-02, -1.2763e-01, -8.9044e-02, -1.7024e-02,
         -6.8019e-02, -1.2509e-01, -5.0442e-01, -9.3473e-02, -4.4497e-02,
         -1.1393e-02, -7.5144e-02, -1.5948e-02, -6.1715e-02, -2.3791e-01,
         -5.9066e-02, -3.9103e-01, -1.4818e-01, -3.0492e-01, -1.0114e-01,
         -5.4421e-02, -2.3949e-02, -1.5516e-02, -6.8998e-02, -3.9734e-02,
         -5.3857e-02, -9.8981e-02, -3.9295e-02, -2.7041e-02, -3.6212e-03,
         -1.0315e-01, -5.9921e-02, -6.4933e-02, -9.0143e-01, -3.0797e-02,
         -1.1285e+00, -3.7405e-02, -5.9726e-02, -5.4390e-02, -1.3348e-01,
         -7.2267e-02, -8.7486e-01, -4.9220e-02, -1.4692e-01, -4.1404e-02,
         -2.9136e-02, -3.9668e-02, -4.8674e-02, -5.2437e-02, -5.1657e-02,
         -1.1372e-01, -1.9949e-01, -8.3018e-02, -4.4488e-02, -4.8662e-02,
         -7.3401e-02, -7.4476e-02, -5.9747e-02, -7.0884e-02, -4.0597e-01,
         -1.1629e-01, -5.4890e-02, -1.4877e-01, -8.4538e-02, -9.1792e-02,
         -1.3569e-01, -5.3069e-02]], device='cuda:0')
model['model']['simple_lm_proj.weight_scale'] = tensor(88.3546, device='cuda:0')
real_weight_forward = model['model']['simple_lm_proj.weight'][750:751,] * model['model']['simple_lm_proj.weight_scale'].exp()
input = torch.zeros((1,1,512),device='cuda:0')
output = torch.matmul(input , real_weight_forward.transpose(0,1))
# output is:
tensor([[[nan]]], device='cuda:0')

# Convert to float64
real_weight_forward_float64 = model['model']['simple_lm_proj.weight'][750:751,].type(torch.float64) * model['model']['simple_lm_proj.weight_scale'].exp().type(torch.float64)
input = torch.zeros((1,1,512),device='cuda:0', dtype=torch.float64)
output = torch.matmul(input , real_weight_forward_float64.transpose(0,1))
# output is:
tensor([[[0.]]], device='cuda:0', dtype=torch.float64)
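
For reference, here is a minimal standalone sketch (with made-up numbers close to the ones reported above) of why this produces nan: once weight * weight_scale.exp() leaves the float32 range (about 3.4e38) it becomes -inf, and 0 * -inf is nan under IEEE arithmetic, whereas float64 stays finite.

import torch

# Hypothetical values near the reported ones; exp(89) already exceeds float32 max.
weight = torch.tensor([-1.5], dtype=torch.float32)
weight_scale = torch.tensor(89.0, dtype=torch.float32)

real_weight = weight * weight_scale.exp()
print(real_weight)                   # tensor([-inf])
print(torch.zeros(1) * real_weight)  # tensor([nan]), since 0 * -inf = nan

# The same computation in float64 stays finite (roughly -6.7e38).
print(weight.double() * weight_scale.double().exp())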
@csukuangfj
Collaborator

Could you use the changes from #531, which have been merged into master, to print out the batch that causes the inf loss?

@pkufool
Collaborator

pkufool commented Aug 18, 2022

What optimizer did you use? Did you use Eve?

@drawfish
Author

What optimizer did you use? Did you use Eve?

Yes, Eve optimizer and Eden scheduler.

@drawfish
Author

Could you use the changes from #531, which have been merged into master, to print out the batch that causes the inf loss?

Yes, I learned about this PR in the WeChat group. But this may be a new problem caused by the decoder and joiner parts of the reworked transducer model.
By printing the model's weight scaling factors, it seems that the factors of the decoder and joiner are an order of magnitude larger than those of the encoder.

encoder.encoder.layers.11.feed_forward.4.bias_scale tensor(-1.3940, device='cuda:0')
encoder.encoder.layers.11.feed_forward_macaron.0.weight_scale tensor(-0.2941, device='cuda:0')
encoder.encoder.layers.11.feed_forward_macaron.0.bias_scale tensor(3.0587, device='cuda:0')
encoder.encoder.layers.11.feed_forward_macaron.4.weight_scale tensor(-1.1440, device='cuda:0')
encoder.encoder.layers.11.feed_forward_macaron.4.bias_scale tensor(-1.4464, device='cuda:0')
encoder.encoder.layers.11.conv_module.pointwise_conv1.weight_scale tensor(-0.9279, device='cuda:0')
encoder.encoder.layers.11.conv_module.pointwise_conv1.bias_scale tensor(2.1177, device='cuda:0')
encoder.encoder.layers.11.conv_module.depthwise_conv.weight_scale tensor(0.9658, device='cuda:0')
encoder.encoder.layers.11.conv_module.depthwise_conv.bias_scale tensor(1.5869, device='cuda:0')
encoder.encoder.layers.11.conv_module.pointwise_conv2.weight_scale tensor(-0.7548, device='cuda:0')
encoder.encoder.layers.11.conv_module.pointwise_conv2.bias_scale tensor(-2.4712, device='cuda:0')
decoder.conv.weight_scale tensor(-40.2921, device='cuda:0')
joiner.encoder_proj.weight_scale tensor(-2.1799, device='cuda:0')
joiner.encoder_proj.bias_scale tensor(-0.4448, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(85.6988, device='cuda:0')
joiner.decoder_proj.bias_scale tensor(-0.4448, device='cuda:0')
joiner.output_linear.weight_scale tensor(3.8689, device='cuda:0')
joiner.output_linear.bias_scale tensor(-6.0607, device='cuda:0')
simple_am_proj.weight_scale tensor(2.1853, device='cuda:0')
simple_am_proj.bias_scale tensor(4.2252, device='cuda:0')
simple_lm_proj.weight_scale tensor(88.3546, device='cuda:0')
simple_lm_proj.bias_scale tensor(5.1967, device='cuda:0')

The outliers are:

decoder.conv.weight_scale tensor(-40.2921, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(85.6988, device='cuda:0')
simple_lm_proj.weight_scale tensor(88.3546, device='cuda:0')

@drawfish
Author

drawfish commented Aug 18, 2022

More debug info:
As training progresses, the weight scaling factors gradually increase:

------------- epoch 0 -------------
decoder.conv.weight_scale tensor(-7.3865, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(18.4063, device='cuda:0')
simple_lm_proj.weight_scale tensor(19.7458, device='cuda:0')
------------- epoch 1 -------------
decoder.conv.weight_scale tensor(-11.3794, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(26.6442, device='cuda:0')
simple_lm_proj.weight_scale tensor(28.1604, device='cuda:0')
------------- epoch 2 -------------
decoder.conv.weight_scale tensor(-14.4583, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(32.8381, device='cuda:0')
simple_lm_proj.weight_scale tensor(34.4227, device='cuda:0')
------------- epoch 3 -------------
decoder.conv.weight_scale tensor(-17.0543, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(38.1340, device='cuda:0')
simple_lm_proj.weight_scale tensor(39.7966, device='cuda:0')
------------- epoch 4 -------------
decoder.conv.weight_scale tensor(-19.2274, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(42.5399, device='cuda:0')
simple_lm_proj.weight_scale tensor(44.2683, device='cuda:0')
------------- epoch 5 -------------
decoder.conv.weight_scale tensor(-21.2067, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(46.6369, device='cuda:0')
simple_lm_proj.weight_scale tensor(48.4233, device='cuda:0')
------------- epoch 6 -------------
decoder.conv.weight_scale tensor(-22.9498, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(50.1861, device='cuda:0')
simple_lm_proj.weight_scale tensor(52.0434, device='cuda:0')
------------- epoch 7 -------------
decoder.conv.weight_scale tensor(-24.5115, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(53.3605, device='cuda:0')
simple_lm_proj.weight_scale tensor(55.3102, device='cuda:0')
------------- epoch 8 -------------
decoder.conv.weight_scale tensor(-25.8015, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(55.8739, device='cuda:0')
simple_lm_proj.weight_scale tensor(57.9971, device='cuda:0')
------------- epoch 9 -------------
decoder.conv.weight_scale tensor(-26.9941, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(58.3102, device='cuda:0')
simple_lm_proj.weight_scale tensor(60.4677, device='cuda:0')
------------- epoch 10 -------------
decoder.conv.weight_scale tensor(-28.0990, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(60.5648, device='cuda:0')
simple_lm_proj.weight_scale tensor(62.7376, device='cuda:0')
------------- epoch 11 -------------
decoder.conv.weight_scale tensor(-29.0849, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(62.5643, device='cuda:0')
simple_lm_proj.weight_scale tensor(64.7944, device='cuda:0')
------------- epoch 12 -------------
decoder.conv.weight_scale tensor(-29.9371, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(64.2516, device='cuda:0')
simple_lm_proj.weight_scale tensor(66.5009, device='cuda:0')
------------- epoch 13 -------------
decoder.conv.weight_scale tensor(-30.7538, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(65.9434, device='cuda:0')
simple_lm_proj.weight_scale tensor(68.2259, device='cuda:0')
------------- epoch 14 -------------
decoder.conv.weight_scale tensor(-31.5200, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(67.5060, device='cuda:0')
simple_lm_proj.weight_scale tensor(69.8165, device='cuda:0')
------------- epoch 15 -------------
decoder.conv.weight_scale tensor(-32.2282, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(68.9826, device='cuda:0')
simple_lm_proj.weight_scale tensor(71.2871, device='cuda:0')
------------- epoch 16 -------------
decoder.conv.weight_scale tensor(-32.8845, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(70.3055, device='cuda:0')
simple_lm_proj.weight_scale tensor(72.6047, device='cuda:0')
------------- epoch 17 -------------
decoder.conv.weight_scale tensor(-33.5106, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(71.6128, device='cuda:0')
simple_lm_proj.weight_scale tensor(73.9309, device='cuda:0')
------------- epoch 18 -------------
decoder.conv.weight_scale tensor(-34.1024, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(72.8120, device='cuda:0')
simple_lm_proj.weight_scale tensor(75.1720, device='cuda:0')
------------- epoch 19 -------------
decoder.conv.weight_scale tensor(-34.6512, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(73.9153, device='cuda:0')
simple_lm_proj.weight_scale tensor(76.3286, device='cuda:0')
------------- epoch 20 -------------
decoder.conv.weight_scale tensor(-35.1654, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(74.9623, device='cuda:0')
simple_lm_proj.weight_scale tensor(77.4025, device='cuda:0')
------------- epoch 21 -------------
decoder.conv.weight_scale tensor(-35.6816, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(76.0416, device='cuda:0')
simple_lm_proj.weight_scale tensor(78.4976, device='cuda:0')
------------- epoch 22 -------------
decoder.conv.weight_scale tensor(-36.1886, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(77.0714, device='cuda:0')
simple_lm_proj.weight_scale tensor(79.5324, device='cuda:0')
------------- epoch 23 -------------
decoder.conv.weight_scale tensor(-36.6584, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(77.9829, device='cuda:0')
simple_lm_proj.weight_scale tensor(80.4837, device='cuda:0')
------------- epoch 24 -------------
decoder.conv.weight_scale tensor(-37.1132, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(78.9565, device='cuda:0')
simple_lm_proj.weight_scale tensor(81.4988, device='cuda:0')
------------- epoch 25 -------------
decoder.conv.weight_scale tensor(-37.5599, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(79.9053, device='cuda:0')
simple_lm_proj.weight_scale tensor(82.4435, device='cuda:0')
------------- epoch 26 -------------
decoder.conv.weight_scale tensor(-38.0059, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(80.8287, device='cuda:0')
simple_lm_proj.weight_scale tensor(83.3836, device='cuda:0')
------------- epoch 27 -------------
decoder.conv.weight_scale tensor(-38.4312, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(81.7031, device='cuda:0')
simple_lm_proj.weight_scale tensor(84.3004, device='cuda:0')
------------- epoch 28 -------------
decoder.conv.weight_scale tensor(-38.8359, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(82.6944, device='cuda:0')
simple_lm_proj.weight_scale tensor(85.2972, device='cuda:0')
------------- epoch 29 -------------
decoder.conv.weight_scale tensor(-39.2581, device='cuda:0')
joiner.decoder_proj.weight_scale tensor(83.6404, device='cuda:0')
simple_lm_proj.weight_scale tensor(86.3040, device='cuda:0')

@danpovey
Collaborator

danpovey commented Aug 18, 2022

Because it's sometimes possible to scale up one layer and scale down the next one without affecting the output, this can happen.
I suggest introducing a maximum for the weight_scale and enforcing it in the optimizer. That is what I am doing right now
anyway, but it's not ready to merge.
We could do, for all scalar parameters in the optimizer, something like
param.clamp_(min=-10, max=2)
as a change to Eve.
[EDIT: I changed this to min=-10, max=2, not -10..10, to better support float16 training. This is the useful range in normal models anyway.]
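
A minimal sketch of what such a change could look like (an illustration of the idea, not the actual Eve implementation; the helper name and the hook point after optimizer.step() are assumptions):

import torch

SCALAR_MIN, SCALAR_MAX = -10.0, 2.0

@torch.no_grad()
def clamp_scalar_params(optimizer: torch.optim.Optimizer) -> None:
    # Clamp every one-element parameter (e.g. the weight_scale / bias_scale of
    # ScaledLinear) so the log-scale can never grow without bound.
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.numel() == 1:
                p.clamp_(min=SCALAR_MIN, max=SCALAR_MAX)

# Usage after each update step:
#   loss.backward()
#   optimizer.step()
#   clamp_scalar_params(optimizer)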

@marcoyang1998
Collaborator

I've added the clamping operation to the Eve optimizer, and here are the WERs I got with the pruned_transducer_stateless5 recipe, with and without clamping, on LibriSpeech 100h. I set min to -10 and max to 2.

                                                 test-clean  test-other
No clamping    greedy search@epoch-30-avg-10       6.70        17.27
With clamping  greedy search@epoch-30-avg-10       6.73        17.31

It seems that the clamping operation does not noticeably affect training. Here are the values of the three scalar weights after each epoch:

Without clamping:

Epoch: 1
Decoder conv weight_scale: 2.532698154449463
Joiner decoder_proj weight_scale: -1.3956085443496704
Simple lm_proj weight_scale: 0.014985241927206516
Epoch: 2
Decoder conv weight_scale: 2.4966623783111572
Joiner decoder_proj weight_scale: -1.3323917388916016
Simple lm_proj weight_scale: -0.04377373680472374
Epoch: 3
Decoder conv weight_scale: 2.408905506134033
Joiner decoder_proj weight_scale: -1.013999104499817
Simple lm_proj weight_scale: -0.248096764087677
Epoch: 4
Decoder conv weight_scale: 2.3934552669525146
Joiner decoder_proj weight_scale: -0.9088234305381775
Simple lm_proj weight_scale: -0.2796499729156494
Epoch: 5
Decoder conv weight_scale: 2.463165283203125
Joiner decoder_proj weight_scale: -1.2767215967178345
Simple lm_proj weight_scale: -0.09834159910678864
Epoch: 6
Decoder conv weight_scale: 2.5755698680877686
Joiner decoder_proj weight_scale: -1.4661520719528198
Simple lm_proj weight_scale: 0.08270139247179031
Epoch: 7
Decoder conv weight_scale: 2.453735113143921
Joiner decoder_proj weight_scale: -1.2604535818099976
Simple lm_proj weight_scale: -0.11372770369052887
Epoch: 8
Decoder conv weight_scale: 2.562387704849243
Joiner decoder_proj weight_scale: -1.4478679895401
Simple lm_proj weight_scale: 0.06404751539230347
Epoch: 9
Decoder conv weight_scale: 2.539534568786621
Joiner decoder_proj weight_scale: -1.40958833694458
Simple lm_proj weight_scale: 0.02723226696252823
Epoch: 10
Decoder conv weight_scale: 2.5854289531707764
Joiner decoder_proj weight_scale: -1.48295259475708
Simple lm_proj weight_scale: 0.09910309314727783
Epoch: 11
Decoder conv weight_scale: 2.4235291481018066
Joiner decoder_proj weight_scale: -1.1637898683547974
Simple lm_proj weight_scale: -0.19667033851146698
Epoch: 12
Decoder conv weight_scale: 2.5119268894195557
Joiner decoder_proj weight_scale: -1.362194299697876
Simple lm_proj weight_scale: -0.018590014427900314
Epoch: 13
Decoder conv weight_scale: 2.4722437858581543
Joiner decoder_proj weight_scale: -1.2908258438110352
Simple lm_proj weight_scale: -0.08399543911218643
Epoch: 14
Decoder conv weight_scale: 2.5198278427124023
Joiner decoder_proj weight_scale: -1.3738354444503784
Simple lm_proj weight_scale: -0.006699979770928621
Epoch: 15
Decoder conv weight_scale: 2.4350645542144775
Joiner decoder_proj weight_scale: -1.2236826419830322
Simple lm_proj weight_scale: -0.15010575950145721
Epoch: 16
Decoder conv weight_scale: 2.5459630489349365
Joiner decoder_proj weight_scale: -1.4185402393341064
Simple lm_proj weight_scale: 0.03665268048644066
Epoch: 17
Decoder conv weight_scale: 2.260655403137207
Joiner decoder_proj weight_scale: -0.8165771961212158
Simple lm_proj weight_scale: -0.5119346976280212
Epoch: 18
Decoder conv weight_scale: 2.5793089866638184
Joiner decoder_proj weight_scale: -1.4762474298477173
Simple lm_proj weight_scale: 0.0903153046965599
Epoch: 19
Decoder conv weight_scale: 2.5563783645629883
Joiner decoder_proj weight_scale: -1.4410549402236938
Simple lm_proj weight_scale: 0.055296555161476135
Epoch: 20
Decoder conv weight_scale: 2.5518383979797363
Joiner decoder_proj weight_scale: -1.4319567680358887
Simple lm_proj weight_scale: 0.04718981310725212
Epoch: 21
Decoder conv weight_scale: 2.443829298019409
Joiner decoder_proj weight_scale: -1.2422311305999756
Simple lm_proj weight_scale: -0.13266611099243164
Epoch: 22
Decoder conv weight_scale: 2.4006619453430176
Joiner decoder_proj weight_scale: -0.8165771961212158
Simple lm_proj weight_scale: -0.3719165623188019
Epoch: 23
Decoder conv weight_scale: 2.4226510524749756
Joiner decoder_proj weight_scale: -1.1069473028182983
Simple lm_proj weight_scale: -0.22385886311531067
Epoch: 24
Decoder conv weight_scale: 2.568734645843506
Joiner decoder_proj weight_scale: -1.4594826698303223
Simple lm_proj weight_scale: 0.0736895278096199
Epoch: 25
Decoder conv weight_scale: 2.4270377159118652
Joiner decoder_proj weight_scale: -1.1974050998687744
Simple lm_proj weight_scale: -0.17232902348041534
Epoch: 26
Decoder conv weight_scale: 2.489833116531372
Joiner decoder_proj weight_scale: -1.3193844556808472
Simple lm_proj weight_scale: -0.05526689812541008
Epoch: 27
Decoder conv weight_scale: 2.525768756866455
Joiner decoder_proj weight_scale: -1.3883850574493408
Simple lm_proj weight_scale: 0.004825897980481386
Epoch: 28
Decoder conv weight_scale: 2.4805877208709717
Joiner decoder_proj weight_scale: -1.3054426908493042
Simple lm_proj weight_scale: -0.07006088644266129
Epoch: 29
Decoder conv weight_scale: 2.366048574447632
Joiner decoder_proj weight_scale: -0.8165771961212158
Simple lm_proj weight_scale: -0.4065343141555786
Epoch: 30
Decoder conv weight_scale: 2.5044734477996826
Joiner decoder_proj weight_scale: -1.3446654081344604
Simple lm_proj weight_scale: -0.031459204852581024

With clamping:

Epoch: 1
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.1262197494506836
Simple lm_proj weight_scale: 1.346535086631775
Epoch: 2
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.13853846490383148
Simple lm_proj weight_scale: 1.2230230569839478
Epoch: 3
Decoder conv weight_scale: 1.9999276399612427
Joiner decoder_proj weight_scale: -0.30957549810409546
Simple lm_proj weight_scale: 0.7861649990081787
Epoch: 4
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.47377195954322815
Simple lm_proj weight_scale: 0.5668039321899414
Epoch: 5
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.15966753661632538
Simple lm_proj weight_scale: 1.1079970598220825
Epoch: 6
Decoder conv weight_scale: 1.9998482465744019
Joiner decoder_proj weight_scale: -0.11061322689056396
Simple lm_proj weight_scale: 1.4907336235046387
Epoch: 7
Decoder conv weight_scale: 1.9994806051254272
Joiner decoder_proj weight_scale: -0.17074772715568542
Simple lm_proj weight_scale: 1.074182152748108
Epoch: 8
Decoder conv weight_scale: 1.9996674060821533
Joiner decoder_proj weight_scale: -0.11555270105600357
Simple lm_proj weight_scale: 1.4526734352111816
Epoch: 9
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.12586870789527893
Simple lm_proj weight_scale: 1.371146321296692
Epoch: 10
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.10725975781679153
Simple lm_proj weight_scale: 1.5275452136993408
Epoch: 11
Decoder conv weight_scale: 1.999825358390808
Joiner decoder_proj weight_scale: -0.23376572132110596
Simple lm_proj weight_scale: 0.9220944046974182
Epoch: 12
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.1301610767841339
Simple lm_proj weight_scale: 1.2763099670410156
Epoch: 13
Decoder conv weight_scale: 1.9997310638427734
Joiner decoder_proj weight_scale: -0.15640747547149658
Simple lm_proj weight_scale: 1.1399130821228027
Epoch: 14
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.13257452845573425
Simple lm_proj weight_scale: 1.3028491735458374
Epoch: 15
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.19641730189323425
Simple lm_proj weight_scale: 1.002395510673523
Epoch: 16
Decoder conv weight_scale: 1.9998756647109985
Joiner decoder_proj weight_scale: -0.11937755346298218
Simple lm_proj weight_scale: 1.3922357559204102
Epoch: 17
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.8165771961212158
Simple lm_proj weight_scale: -0.21228034794330597
Epoch: 18
Decoder conv weight_scale: 1.9999603033065796
Joiner decoder_proj weight_scale: -0.1086028665304184
Simple lm_proj weight_scale: 1.5089606046676636
Epoch: 19
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.12050770223140717
Simple lm_proj weight_scale: 1.4338239431381226
Epoch: 20
Decoder conv weight_scale: 1.9998784065246582
Joiner decoder_proj weight_scale: -0.12354948371648788
Simple lm_proj weight_scale: 1.4136296510696411
Epoch: 21
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.1823224574327469
Simple lm_proj weight_scale: 1.037805199623108
Epoch: 22
Decoder conv weight_scale: 1.9997987747192383
Joiner decoder_proj weight_scale: -0.8165771961212158
Simple lm_proj weight_scale: 0.3626188039779663
Epoch: 23
Decoder conv weight_scale: 1.9996812343597412
Joiner decoder_proj weight_scale: -0.2594310939311981
Simple lm_proj weight_scale: 0.8760554790496826
Epoch: 24
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.11035250872373581
Simple lm_proj weight_scale: 1.4716155529022217
Epoch: 25
Decoder conv weight_scale: 1.9995653629302979
Joiner decoder_proj weight_scale: -0.21520820260047913
Simple lm_proj weight_scale: 0.9618409276008606
Epoch: 26
Decoder conv weight_scale: 1.9997872114181519
Joiner decoder_proj weight_scale: -0.13693207502365112
Simple lm_proj weight_scale: 1.196391224861145
Epoch: 27
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.12907470762729645
Simple lm_proj weight_scale: 1.3241920471191406
Epoch: 28
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.1475147157907486
Simple lm_proj weight_scale: 1.1672462224960327
Epoch: 29
Decoder conv weight_scale: 2.0
Joiner decoder_proj weight_scale: -0.8165771961212158
Simple lm_proj weight_scale: 0.10545660555362701
Epoch: 30
Decoder conv weight_scale: 1.9997196197509766
Joiner decoder_proj weight_scale: -0.13277317583560944
Simple lm_proj weight_scale: 1.250410556793213

It seems that one scalar weight is saturating: the clamping function pulls it back every time. The WER is fine though, so this might not be a problem. Should I try another clamping range, e.g. (-10, 3)?

@danpovey
Collaborator

danpovey commented Sep 7, 2022

That threshold is fine. These models have degrees of freedom: one scale can get big while the other gets small.
