[Autograd] Very serious bug of grad_req='add' #17989
Comments
Also, a deeper dive into the problem shows that the issue appears when one layer is reused 5, 6, or 7 times:

```bash
wget https://gist.githubusercontent.com/sxjscience/0bd336c921396b3c66331354e1866886/raw/80a428980fd91110455e847c1a02aef4ae2cba7f/grad_req_addto_bug.py -O grad_req_addto_bug.py
for nrepeat in 1 2 3 4 5 6 7 8 9 10
do
  echo "nrepeat=${nrepeat}"
  echo "with addto"
  python grad_req_addto_bug.py --addto --nrepeat ${nrepeat}
  echo "without addto"
  python grad_req_addto_bug.py --nrepeat ${nrepeat}
done
```

Result:

This shows that it's only wrong when nrepeat is 5, 6, or 7.
Also, the mismatch persists after adding zero_grad before the forward and backward passes. Updated script:

```bash
wget https://gist.githubusercontent.com/sxjscience/0bd336c921396b3c66331354e1866886/raw/d618ba69cbecf04d3013db77af86c29d62fe0336/grad_req_addto_bug.py -O grad_req_addto_bug.py
python grad_req_addto_bug.py --addto
python grad_req_addto_bug.py
```
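For reference, resetting accumulated gradients in Gluon is just a zero_grad() call on the block's parameters. A minimal self-contained sketch of that pattern (my own illustration using the 1.x Gluon API seen elsewhere in this thread, not the linked gist):

```python
import mxnet as mx
from mxnet.gluon import nn

# Minimal sketch of what "adding zero_grad" refers to: with grad_req='add',
# gradients accumulate across backward calls unless they are explicitly reset.
net = nn.Dense(4, in_units=4)
net.initialize()
for p in net.collect_params().values():
    p.grad_req = 'add'

x = mx.nd.ones((2, 4))
for step in range(2):
    net.collect_params().zero_grad()        # reset the accumulated gradient buffers
    with mx.autograd.record():
        loss = net(x).sum()
    loss.backward()
    # With the reset in place, the printed value is identical every step.
    print(step, net.weight.grad().sum())
```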
I discovered this bug when trying different parameters of ALBERT. In the original ALBERT, the number of layers is 12 or 24; neither triggers the bug, so it took me some time to isolate the issue.
Also, the bug occurs in the hybridized case.
It also appears in the legacy mx.nd interface:

```python
import mxnet as mx
from mxnet.gluon import nn, HybridBlock
import numpy as np
import argparse

np.random.seed(123)
mx.random.seed(123)

parser = argparse.ArgumentParser(description='Grad req bug minimal example')
parser.add_argument('--addto', action='store_true')
parser.add_argument('--hybridize', action='store_true')
parser.add_argument('--nrepeat', type=int, default=5)
args = parser.parse_args()


class Foo(HybridBlock):
    def __init__(self, prefix=None, params=None):
        super().__init__(prefix=prefix, params=params)
        with self.name_scope():
            self.layer = nn.Dense(16)

    def hybrid_forward(self, F, dat):
        out = dat
        for _ in range(args.nrepeat):
            out = self.layer(out)
        return out


foo = Foo()
if args.hybridize:
    foo.hybridize()
foo.initialize(ctx=mx.gpu())
if args.addto:
    for p in foo.collect_params().values():
        p.grad_req = 'add'
dat = mx.nd.random.normal(0, 1, (32, 16), ctx=mx.gpu())
og = mx.nd.random.normal(0, 1, (32, 16), ctx=mx.gpu())
with mx.autograd.record():
    out = foo(dat)
    loss = (out * og).sum()
loss.backward()
for k, v in foo.collect_params().items():
    print(k, mx.nd.norm(v.grad()))
```
```
ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug_nd.py --nrepeat 5 --hybridize
foo0_dense0_weight
[0.16300175]
<NDArray 1 @gpu(0)>
foo0_dense0_bias
[27.344622]
<NDArray 1 @gpu(0)>
ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug_nd.py --nrepeat 5 --hybridize --addto
foo0_dense0_weight
[1.3425881]
<NDArray 1 @gpu(0)>
foo0_dense0_bias
[424.70026]
<NDArray 1 @gpu(0)>
```
Just verified that there is no problem when running on CPU, while the GPU result is still wrong, using the new numpy interface:

```python
import mxnet as mx
import numpy as np
mx.npx.set_np()

for ctx in [mx.cpu(), mx.gpu()]:
    for nrepeat in range(1, 10):
        stored_grad = dict()
        for grad_req in ['write', 'add']:
            a = mx.np.array(1, ctx=ctx)
            b = mx.np.array(2, ctx=ctx)
            if grad_req == 'write':
                a.attach_grad(grad_req='write')
            elif grad_req == 'add':
                a.attach_grad(grad_req='add')
                a.grad[()] = 0
            with mx.autograd.record():
                for _ in range(nrepeat):
                    b = b * a
            b.backward()
            stored_grad[grad_req] = a.grad.asnumpy()
        print('ctx={}, nrepeat={}, write={}, add={}'.format(ctx, nrepeat, stored_grad['write'], stored_grad['add']))
```

Result:
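For reference, the value both grad_req settings should print can be worked out by hand, since after n repeats b = 2 * a**n. A quick standalone check (my own sketch, independent of MXNet):

```python
# Expected gradient for the script above: b = 2 * a**nrepeat with a = 1,
# so db/da = 2 * nrepeat * a**(nrepeat - 1) = 2 * nrepeat.
# 'write' and 'add' should therefore print the same number, since the
# gradient buffer starts at zero.
for nrepeat in range(1, 10):
    a = 1.0
    expected = 2 * nrepeat * a ** (nrepeat - 1)
    print('nrepeat={}, expected grad={}'.format(nrepeat, expected))
```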
@eric-haibin-lin @szha @szhengac @zhreshold This is the worst problem I've found and it impacts all models with shared parameters.
This bug does affect many models with parameter sharing.
Sure
Hmm, I just tried the latest script from @sxjscience and I got exactly the opposite result: GPU working as expected and CPU not working right.
I ran it ~15 times and always got the same result. What is the version of MXNet you tried it on, @sxjscience?
@ptrendx @szha @zhreshold I find that the bug also exists in 1.5.0, 1.4.0, 1.3.1, and 1.2.1. In fact, results on both CPU and GPU are wrong in these versions. A reproducible script is given as follows (I used the legacy mx.nd):

```python
import mxnet as mx
import numpy as np

for ctx in [mx.cpu(), mx.gpu()]:
    for nrepeat in range(1, 10):
        stored_grad = dict()
        for grad_req in ['write', 'add']:
            a = mx.nd.array([1], ctx=ctx)
            b = mx.nd.array([2], ctx=ctx)
            if grad_req == 'write':
                a.attach_grad(grad_req='write')
            elif grad_req == 'add':
                a.attach_grad(grad_req='add')
                a.grad[:] = 0
            with mx.autograd.record():
                for _ in range(nrepeat):
                    b = b * a
            b.backward()
            stored_grad[grad_req] = a.grad.asscalar()
        print('ctx={}, nrepeat={}, write={}, add={}'.format(ctx, nrepeat, stored_grad['write'], stored_grad['add']))
```

For MXNet 1.5.0, I used

Output:
@ptrendx @zhreshold @szha I tried to run with MXNet==1.0.0 but it gives me another error. The earliest version I can confirm has this issue is 1.2.0. This is really critical and impacts the very basic functionality of a DL framework, i.e., autograd.
@ptrendx I'm using a compiled version of master. Are you able to reproduce it using the script I attached at the beginning of the issue?
I first tried our container (which is based on 1.6.0), since that was the easiest thing for me to try first. When I tried both your numpy and ndarray small examples, I saw CPU exhibiting the error and GPU not, so I am looking into what could be different between those implementations.
Maybe it's not related to a specific implementation on the GPU side.
It is ElementwiseSum, although I'm not sure why after 7 repeats you get the correct result again.

```cpp
default: {
  DType* in_0_dptr = in_data[0].dptr<DType>();
  Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, req[0], in_0_dptr);
  for (size_t i = 1; i < size; ++i) {
    DType* in_dptr = in_data[i].dptr<DType>();
    Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, req[0], out_dptr, in_dptr);
  }
  break;
}
```

which is wrong - if `req[0]` is `kAddTo`, the first launch correctly does `out += in_0`,
but then the subsequent ones, instead of doing `out += in_i`,
do `out += out + in_i`, because `req[0]` is applied again on top of a sum that already contains `out`.
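To see the effect numerically, here is a small NumPy simulation of that branch under kAddTo (my own illustration, not MXNet code):

```python
import numpy as np

# Simulate the "default" branch above when req[0] == kAddTo, i.e. when the
# output buffer already holds a previously accumulated gradient.
n_inputs = 5                                   # 5 gradient terms -> the failing case
inputs = [np.full(4, 1.0) for _ in range(n_inputs)]
prior = np.full(4, 10.0)                       # gradient accumulated before this op

correct = prior + sum(inputs)                  # what grad_req='add' should produce

buggy = prior.copy()
buggy += inputs[0]                             # first launch: out += in_0 (still fine)
for in_i in inputs[1:]:
    buggy += buggy + in_i                      # later launches: out += (out + in_i)

print('correct:', correct[0])                  # 15.0
print('buggy:  ', buggy[0])                    # 191.0 -- the old value keeps doubling
```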
I don't see this issue in our (not yet released) container because I changed the ElementwiseSum implementation to be vectorized, and there I do have the proper logic:

```cpp
for (size_t i = 0; i < inputs.size(); i += num_inputs_per_kernel) {
  if (i == 0) {
    using Kernel = VectorizedElementwiseSumFwd<DType, Req>;
    typename Kernel::ParamType params;
    params.num_inputs = std::min(num_inputs_per_kernel, inputs.size() - i);
    for (int j = 0; j < params.num_inputs; ++j) {
      params.inputs[j] = inputs[i + j].dptr<DType>();
    }
    params.outputs[0] = outputs[0].dptr<DType>();
    VectorizedKernelLauncher<DType, LType, Kernel>(size, s, params);
  } else {
    /* During subsequent launches we need to
       accumulate into the previous outputs
    */
    using Kernel = VectorizedElementwiseSumFwd<DType, kAddTo>;
    typename Kernel::ParamType params;
    params.num_inputs = std::min(num_inputs_per_kernel, inputs.size() - i);
    for (int j = 0; j < params.num_inputs; ++j) {
      params.inputs[j] = inputs[i + j].dptr<DType>();
    }
    params.outputs[0] = outputs[0].dptr<DType>();
    VectorizedKernelLauncher<DType, LType, Kernel>(size, s, params);
  }
}
```

where I changed the sum to apply the requested `Req` only in the first launch and to always accumulate (`kAddTo`) in the subsequent launches.
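The same accumulation pattern in plain NumPy terms, for readers not fluent in the kernel code (a sketch of the logic only, assuming an arbitrary chunk size of 4 inputs per launch):

```python
import numpy as np

# Sketch of the corrected logic: the requested req is honoured only by the
# first launch; every later launch accumulates (kAddTo) into the same output.
def elementwise_sum(inputs, out, req, chunk=4):   # chunk ~ num_inputs_per_kernel
    for i in range(0, len(inputs), chunk):
        partial = np.sum(inputs[i:i + chunk], axis=0)
        if i == 0 and req == 'write':
            out[:] = partial                      # first launch, kWriteTo
        else:
            out += partial                        # first launch with kAddTo, or any later launch
    return out

inputs = [np.full(3, 1.0) for _ in range(6)]
out = np.full(3, 10.0)                            # previously accumulated gradient
print(elementwise_sum(inputs, out, 'add'))        # [16. 16. 16.] == 10 + 6
```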
@ptrendx Thanks! I think that explains the cause.
It does not explain why it does the right thing for nrepeat=8 and beyond, though. There has to be something else going on that limits ElementwiseSum to 7 inputs somehow.
I don't think it's related to a particular op implementation; it's something that may not be working at all when autograd is involved. What I did was replicate the same node N times: if N is in (1, 2, 3, 4, 8, 9, 10, ...), the loss and gradients are always GOOD; however, with (5, 6, 7), the gradients diverge at the first iteration.
@zhreshold For any op that you check, the reuse introduces the ElementwiseSum node in the backward pass that aggregates the gradients, which you do not see in the model itself. As I said, I still do not understand why 8+ is fine; I would expect everything above 4 to fail (the ElementwiseSum implementation has special cases for 2, 3, and 4 inputs, and then the buggy one for 5+). I will look into it further.
@ptrendx After checking the source code, I think it's due to the gradient aggregation logic: when the number of gradients to be summed is below inplace_sum_cap, they are aggregated with a single ElementwiseSum node, and only at or above the cap does the in-place accumulation path kick in. And the inplace_sum_cap (controlled by an environment variable) is by default 8.
Yup, setting that env variable (undocumented BTW ;-) ) to a higher value makes the test fail for the higher cases too.
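A rough Python restatement of the aggregation choice being discussed (my paraphrase of the behaviour described above, not the actual C++ pass; 8 is the default cap mentioned):

```python
import numpy as np

def aggregate_gradient(grads, inplace_sum_cap=8):
    # Paraphrase of the gradient-aggregation choice: few terms -> one
    # ElementwiseSum node; many terms -> in-place accumulation.
    if len(grads) == 1:
        return grads[0]
    if len(grads) < inplace_sum_cap:
        # A single ElementwiseSum (add_n) sums everything at once.  With
        # grad_req='add' its output req is kAddTo, so 5-7 inputs hit the
        # buggy "default" branch shown earlier.
        return np.sum(grads, axis=0)
    # With inplace_sum_cap or more terms, gradients are accumulated in place,
    # a chunk at a time, which avoids that branch -- hence nrepeat >= 8 looks fine.
    out = grads[0].copy()
    for g in grads[1:]:
        out += g
    return out

print(aggregate_gradient([np.ones(3) for _ in range(6)]))   # [6. 6. 6.]
```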
Minimal reproducible example:

The gradient will be different when `add` is triggered, even though the gradient is initialized to all zeros, so the results should match the `write` case. To reproduce the bug, we can use the following command:

new script with zero_grad

The output is: