[Autograd] Very serious bug of grad_req='add' #17989

Closed · sxjscience opened this issue Apr 7, 2020 · 22 comments · Fixed by #17995
sxjscience (Member) commented Apr 7, 2020

Minimal reproducible example:

import mxnet as mx
from mxnet.gluon import nn, HybridBlock
import numpy as np
import argparse
mx.npx.set_np()

np.random.seed(123)
mx.random.seed(123)


parser = argparse.ArgumentParser(
        description='Grad req bug minimal example')
parser.add_argument('--addto', action='store_true')
parser.add_argument('--hybridize', action='store_true')
parser.add_argument('--nrepeat', type=int, default=5)
args = parser.parse_args()

class Foo(HybridBlock):
    def __init__(self, prefix=None, params=None):
        super().__init__(prefix=prefix, params=params)
        with self.name_scope():
            self.layer = nn.Dense(16)

    def hybrid_forward(self, F, dat):
        out = dat
        for _ in range(args.nrepeat):
            out = self.layer(out)
        return out

foo = Foo()
if args.hybridize:
    foo.hybridize()
foo.initialize(ctx=mx.gpu())

if args.addto:
    for p in foo.collect_params().values():
        p.grad_req = 'add'
foo.collect_params().zero_grad()


dat = mx.np.random.normal(0, 1, (32, 16), ctx=mx.gpu())
og = mx.np.random.normal(0, 1, (32, 16), ctx=mx.gpu())
with mx.autograd.record():
    out = foo(dat)
    loss = (out * og).sum()
    loss.backward()
for k, v in foo.collect_params().items():
    print(k, mx.np.linalg.norm(v.grad()))

The gradient is different when 'add' is used, even though the gradient buffer is initialized to all zeros.
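For context, grad_req='add' is supposed to accumulate gradients into the existing grad buffer instead of overwriting it, so after zeroing the gradients a single backward pass should produce exactly the same values as grad_req='write'. A minimal sketch of the intended semantics (an illustration using the legacy mx.nd interface, not part of the reproduction above):

import mxnet as mx

x = mx.nd.array([1.0, 2.0, 3.0])
x.attach_grad(grad_req='add')  # grad buffer starts at zero and accumulates
for _ in range(2):
    with mx.autograd.record():
        y = (x * x).sum()
    y.backward()
# Two accumulated backward passes of d(x*x)/dx = 2*x give grad = 4*x
print(x.grad)  # expected: [ 4.  8. 12.]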

To reproduce the bug, run the following commands (the script below is the updated version that calls zero_grad):

wget https://gist.githubusercontent.com/sxjscience/0bd336c921396b3c66331354e1866886/raw/d618ba69cbecf04d3013db77af86c29d62fe0336/grad_req_addto_bug.py -O grad_req_addto_bug.py
python grad_req_addto_bug.py --addto
python grad_req_addto_bug.py

The output is:

ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug.py --addto

foo0_dense0_weight 1.7802463
foo0_dense0_bias 315.05945
ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug.py
foo0_dense0_weight 0.19310227
foo0_dense0_bias 19.947094
sxjscience added the Bug label Apr 7, 2020

sxjscience (Member Author) commented Apr 7, 2020

Also, a deeper dive into the problem shows that the issue appears only when the layer is reused 5, 6, or 7 times:

wget https://gist.githubusercontent.com/sxjscience/0bd336c921396b3c66331354e1866886/raw/80a428980fd91110455e847c1a02aef4ae2cba7f/grad_req_addto_bug.py -O grad_req_addto_bug.py
for nrepeat in 1 2 3 4 5 6 7 8 9 10
do
    echo "nrepeat=${nrepeat}"
    echo "with addto"
    python grad_req_addto_bug.py --addto --nrepeat ${nrepeat}
    echo "without addto"
    python grad_req_addto_bug.py --nrepeat ${nrepeat}
done

Result:

nrepeat=1
with addto
foo0_dense0_weight 86.363464
foo0_dense0_bias 19.548544
without addto
foo0_dense0_weight 86.363464
foo0_dense0_bias 19.548544
nrepeat=2
with addto
foo0_dense0_weight 21.480562
foo0_dense0_bias 19.870453
without addto
foo0_dense0_weight 21.480562
foo0_dense0_bias 19.870453
nrepeat=3
with addto
foo0_dense0_weight 4.64952
foo0_dense0_bias 19.938385
without addto
foo0_dense0_weight 4.64952
foo0_dense0_bias 19.938385
nrepeat=4
with addto
foo0_dense0_weight 0.94337225
foo0_dense0_bias 19.947392
without addto
foo0_dense0_weight 0.94337225
foo0_dense0_bias 19.947392
nrepeat=5
with addto
foo0_dense0_weight 1.7802463
foo0_dense0_bias 315.05945
without addto
foo0_dense0_weight 0.19310227
foo0_dense0_bias 19.947094
nrepeat=6
with addto
foo0_dense0_weight 0.6738244
foo0_dense0_bias 630.11865
without addto
foo0_dense0_weight 0.041728128
foo0_dense0_bias 19.946844
nrepeat=7
with addto
foo0_dense0_weight 0.26325437
foo0_dense0_bias 1260.2372
without addto
foo0_dense0_weight 0.009131842
foo0_dense0_bias 19.946758
nrepeat=8
with addto
foo0_dense0_weight 0.0020059107
foo0_dense0_bias 19.946749
without addto
foo0_dense0_weight 0.0020059107
foo0_dense0_bias 19.946749
nrepeat=9
with addto
foo0_dense0_weight 0.00045126013
foo0_dense0_bias 19.946749
without addto
foo0_dense0_weight 0.00045126013
foo0_dense0_bias 19.946749
nrepeat=10
with addto
foo0_dense0_weight 0.00010413639
foo0_dense0_bias 19.946749
without addto
foo0_dense0_weight 0.00010413639
foo0_dense0_bias 19.946749

This shows that the result is only wrong when nrepeat is 5, 6, or 7.

sxjscience (Member Author) commented:

Also, adding zero_grad before mx.autograd.record() does not solve the problem. I've revised the script in the gist, and you can try the new version:

wget https://gist.githubusercontent.com/sxjscience/0bd336c921396b3c66331354e1866886/raw/d618ba69cbecf04d3013db77af86c29d62fe0336/grad_req_addto_bug.py -O grad_req_addto_bug.py
python grad_req_addto_bug.py --addto
python grad_req_addto_bug.py
ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug.py --addto

foo0_dense0_weight 1.7802463
foo0_dense0_bias 315.05945
ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug.py
foo0_dense0_weight 0.19310227
foo0_dense0_bias 19.947094

sxjscience (Member Author) commented:

I discovered this bug when trying different parameter settings for ALBERT. In the original ALBERT, the number of layers is 12 or 24, and neither triggers the bug, so it took me some time to localize the issue.

ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug.py --addto --nrepeat 12

foo0_dense0_weight 5.412447e-06
foo0_dense0_bias 19.946749
ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug.py --nrepeat 12
foo0_dense0_weight 5.412447e-06
foo0_dense0_bias 19.946749
ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug.py --addto --nrepeat 24

foo0_dense0_weight 5.706055e-14
foo0_dense0_bias 19.946749
ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug.py --nrepeat 24
foo0_dense0_weight 5.706055e-14

Also, the bug occurs in the hybridized case.

ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug.py --addto --nrepeat 5 --hybridize

foo0_dense0_weight 1.7802463
foo0_dense0_bias 315.05945
ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug.py  --nrepeat 5 --hybridize
foo0_dense0_weight 0.19310227
foo0_dense0_bias 19.947094

It also appears in the legacy mx.nd interface:

import mxnet as mx
from mxnet.gluon import nn, HybridBlock
import numpy as np
import argparse

np.random.seed(123)
mx.random.seed(123)


parser = argparse.ArgumentParser(
        description='Grad req bug minimal example')
parser.add_argument('--addto', action='store_true')
parser.add_argument('--hybridize', action='store_true')
parser.add_argument('--nrepeat', type=int, default=5)
args = parser.parse_args()

class Foo(HybridBlock):
    def __init__(self, prefix=None, params=None):
        super().__init__(prefix=prefix, params=params)
        with self.name_scope():
            self.layer = nn.Dense(16)

    def hybrid_forward(self, F, dat):
        out = dat
        for _ in range(args.nrepeat):
            out = self.layer(out)
        return out

foo = Foo()
if args.hybridize:
    foo.hybridize()
foo.initialize(ctx=mx.gpu())

if args.addto:
    for p in foo.collect_params().values():
        p.grad_req = 'add'



dat = mx.nd.random.normal(0, 1, (32, 16), ctx=mx.gpu())
og = mx.nd.random.normal(0, 1, (32, 16), ctx=mx.gpu())
with mx.autograd.record():
    out = foo(dat)
    loss = (out * og).sum()
    loss.backward()
for k, v in foo.collect_params().items():
    print(k, mx.nd.norm(v.grad()))
The output is:

ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug_nd.py  --nrepeat 5 --hybridize
foo0_dense0_weight 
[0.16300175]
<NDArray 1 @gpu(0)>
foo0_dense0_bias 
[27.344622]
<NDArray 1 @gpu(0)>
ubuntu@ip-172-31-27-255:~$ python grad_req_addto_bug_nd.py  --nrepeat 5 --hybridize --addto
foo0_dense0_weight 
[1.3425881]
<NDArray 1 @gpu(0)>
foo0_dense0_bias 
[424.70026]
<NDArray 1 @gpu(0)>

sxjscience (Member Author) commented:

Just verified that there is no problem when ctx=mx.cpu(). Also, I've found a simpler script that reproduces the problem:

import mxnet as mx
import numpy as np
mx.npx.set_np()


for ctx in [mx.cpu(), mx.gpu()]:
    for nrepeat in range(1, 10):
        stored_grad = dict()
        for grad_req in ['write', 'add']:
            a = mx.np.array(1, ctx=ctx)
            b = mx.np.array(2, ctx=ctx)
            if grad_req == 'write':
                a.attach_grad(grad_req='write')
            elif grad_req == 'add':
                a.attach_grad(grad_req='add')
            a.grad[()] = 0
            with mx.autograd.record():
                for _ in range(nrepeat):
                    b = b * a
                b.backward()
            stored_grad[grad_req] = a.grad.asnumpy()
        print('ctx={}, nrepeat={}, write={}, add={}'.format(ctx, nrepeat, stored_grad['write'], stored_grad['add']))

Result:

ctx=cpu(0), nrepeat=1, write=2.0, add=2.0
ctx=cpu(0), nrepeat=2, write=4.0, add=4.0
ctx=cpu(0), nrepeat=3, write=6.0, add=6.0
ctx=cpu(0), nrepeat=4, write=8.0, add=8.0
ctx=cpu(0), nrepeat=5, write=10.0, add=10.0
ctx=cpu(0), nrepeat=6, write=12.0, add=12.0
ctx=cpu(0), nrepeat=7, write=14.0, add=14.0
ctx=cpu(0), nrepeat=8, write=16.0, add=16.0
ctx=cpu(0), nrepeat=9, write=18.0, add=18.0
ctx=gpu(0), nrepeat=1, write=2.0, add=2.0
ctx=gpu(0), nrepeat=2, write=4.0, add=4.0
ctx=gpu(0), nrepeat=3, write=6.0, add=6.0
ctx=gpu(0), nrepeat=4, write=8.0, add=8.0
ctx=gpu(0), nrepeat=5, write=10.0, add=62.0
ctx=gpu(0), nrepeat=6, write=12.0, add=126.0
ctx=gpu(0), nrepeat=7, write=14.0, add=254.0
ctx=gpu(0), nrepeat=8, write=16.0, add=16.0
ctx=gpu(0), nrepeat=9, write=18.0, add=18.0
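
For reference, the write column can be checked analytically: the loop computes b = 2 * a**nrepeat, so db/da = 2 * nrepeat * a**(nrepeat - 1), which equals 2 * nrepeat at a = 1. A quick sanity check, independent of MXNet:

# Analytic gradient of b = 2 * a**n with respect to a, evaluated at a = 1
for n in range(1, 10):
    print('nrepeat={}, expected grad={}'.format(n, 2.0 * n))

Since the gradient buffer is zeroed before each backward pass, the add column should match these values as well.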

sxjscience (Member Author) commented:

@eric-haibin-lin @szha @szhengac @zhreshold This is the most serious problem I've found, and it impacts all models that use grad_req='add'.

zheyuye (Contributor) commented Apr 7, 2020

This bug affects many models that use parameter sharing.

szha (Member) commented Apr 7, 2020

@ptrendx both this issue and #16708 seem to be GPU-specific. Would you mind taking a look?

ptrendx (Member) commented Apr 7, 2020

Sure

ptrendx self-assigned this Apr 7, 2020

ptrendx (Member) commented Apr 7, 2020

Hmm, I just tried the latest script from @sxjscience and got exactly the opposite result: GPU working as expected and CPU not working right:

ctx=cpu(0), nrepeat=1, write=2.0, add=2.0
ctx=cpu(0), nrepeat=2, write=4.0, add=4.0
ctx=cpu(0), nrepeat=3, write=6.0, add=6.0
ctx=cpu(0), nrepeat=4, write=8.0, add=8.0
ctx=cpu(0), nrepeat=5, write=10.0, add=62.0
ctx=cpu(0), nrepeat=6, write=12.0, add=126.0
ctx=cpu(0), nrepeat=7, write=14.0, add=254.0
ctx=cpu(0), nrepeat=8, write=16.0, add=16.0
ctx=cpu(0), nrepeat=9, write=18.0, add=18.0
ctx=gpu(0), nrepeat=1, write=2.0, add=2.0
ctx=gpu(0), nrepeat=2, write=4.0, add=4.0
ctx=gpu(0), nrepeat=3, write=6.0, add=6.0
ctx=gpu(0), nrepeat=4, write=8.0, add=8.0
ctx=gpu(0), nrepeat=5, write=10.0, add=10.0
ctx=gpu(0), nrepeat=6, write=12.0, add=12.0
ctx=gpu(0), nrepeat=7, write=14.0, add=14.0
ctx=gpu(0), nrepeat=8, write=16.0, add=16.0
ctx=gpu(0), nrepeat=9, write=18.0, add=18.0

I ran it ~15 times and always got the same result. Which version of MXNet did you try it on, @sxjscience?

sxjscience (Member Author) commented:

@ptrendx @szha @zhreshold I find that the bug also exists in 1.5.0, 1.4.0, 1.3.1, and 1.2.1. In fact, the results on both CPU and GPU are wrong in these versions. A reproducible script is given below (I used the legacy mx.nd interface).

import mxnet as mx
import numpy as np


for ctx in [mx.cpu(), mx.gpu()]:
    for nrepeat in range(1, 10):
        stored_grad = dict()
        for grad_req in ['write', 'add']:
            a = mx.nd.array([1], ctx=ctx)
            b = mx.nd.array([2], ctx=ctx)
            if grad_req == 'write':
                a.attach_grad(grad_req='write')
            elif grad_req == 'add':
                a.attach_grad(grad_req='add')
            a.grad[:] = 0
            with mx.autograd.record():
                for _ in range(nrepeat):
                    b = b * a
                b.backward()
            stored_grad[grad_req] = a.grad.asscalar()
        print('ctx={}, nrepeat={}, write={}, add={}'.format(ctx, nrepeat, stored_grad['write'], stored_grad['add']))

For MXNet 1.5.0, I used pip install mxnet-cu101==1.5.0
For MXNet 1.4.0, I used pip install mxnet-cu92==1.4.0
For MXNet 1.3.1, I used pip install mxnet-cu92==1.3.1
For MXNet 1.2.1, I used pip install mxnet-cu92==1.2.1

Output

ctx=cpu(0), nrepeat=1, write=2.0, add=2.0
ctx=cpu(0), nrepeat=2, write=4.0, add=4.0
ctx=cpu(0), nrepeat=3, write=6.0, add=6.0
ctx=cpu(0), nrepeat=4, write=8.0, add=8.0
ctx=cpu(0), nrepeat=5, write=10.0, add=62.0
ctx=cpu(0), nrepeat=6, write=12.0, add=126.0
ctx=cpu(0), nrepeat=7, write=14.0, add=254.0
ctx=cpu(0), nrepeat=8, write=16.0, add=16.0
ctx=cpu(0), nrepeat=9, write=18.0, add=18.0
ctx=gpu(0), nrepeat=1, write=2.0, add=2.0
ctx=gpu(0), nrepeat=2, write=4.0, add=4.0
ctx=gpu(0), nrepeat=3, write=6.0, add=6.0
ctx=gpu(0), nrepeat=4, write=8.0, add=8.0
ctx=gpu(0), nrepeat=5, write=10.0, add=62.0
ctx=gpu(0), nrepeat=6, write=12.0, add=126.0
ctx=gpu(0), nrepeat=7, write=14.0, add=254.0
ctx=gpu(0), nrepeat=8, write=16.0, add=16.0
ctx=gpu(0), nrepeat=9, write=18.0, add=18.0

sxjscience (Member Author) commented:

@ptrendx @zhreshold @szha I tried to run with MXNet==1.0.0 but it gives me another error. The earliest version in which I can confirm the issue is 1.2.0. This is really critical and impacts the most basic functionality of a DL framework, i.e., autograd.

sxjscience (Member Author) commented:

@ptrendx I'm using a compiled version of master. Are you able to reproduce it using the script I attached at the beginning of the issue?

wget https://gist.githubusercontent.com/sxjscience/0bd336c921396b3c66331354e1866886/raw/d618ba69cbecf04d3013db77af86c29d62fe0336/grad_req_addto_bug.py -O grad_req_addto_bug.py
python grad_req_addto_bug.py --addto
python grad_req_addto_bug.py

sxjscience changed the title from "[Gradient Addto] Very serious bug of grad_req='add'" to "[Autograd] Very serious bug of grad_req='add'" on Apr 7, 2020

ptrendx (Member) commented Apr 7, 2020

I first tried our container (which is based on 1.6.0), since that was the easiest thing for me to try. When I ran both your numpy and ndarray small examples, I saw the CPU exhibiting the error and the GPU not, so I'm looking into what could be different between those implementations.

sxjscience (Member Author) commented:

Maybe it's not related to a specific implementation on the GPU side.

ptrendx (Member) commented Apr 7, 2020

It is ElementwiseSum, although I'm not sure why you get the correct result again after 7 repeats.
The code of ElementwiseSum for more than 4 inputs is:

    default: {
      DType* in_0_dptr = in_data[0].dptr<DType>();
      Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, req[0], in_0_dptr);
      for (size_t i = 1; i < size; ++i) {
        DType* in_dptr = in_data[i].dptr<DType>();
        Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, req[0], out_dptr, in_dptr);
      }
      break;
    }

which is wrong: if req[0] is kAddTo (which it is in this case), after the first kernel launch you get

out = out + in_0

but the subsequent launches, instead of doing

out = out + in_i

do

out += out + in_i

i.e. out = 2 * out + in_i, so the partially accumulated sum is doubled at every step.
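
A minimal Python simulation of this faulty accumulation (an illustration, assuming each of the nrepeat gradient contributions equals 2, as in the a = 1, b = 2 example above, and that the zeroed gradient buffer plays the role of out) reproduces the wrong 'add' values exactly:

# Buggy kAddTo path of ElementwiseSum: the first launch does out += in_0,
# every later launch does out += (out + in_i), i.e. out = 2 * out + in_i.
def buggy_elementwise_sum_addto(contributions, out=0.0):
    out += contributions[0]
    for c in contributions[1:]:
        out += out + c
    return out

for n in (5, 6, 7):
    print(n, buggy_elementwise_sum_addto([2.0] * n))
# prints 62.0, 126.0, 254.0 -- the wrong 'add' values observed above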

ptrendx (Member) commented Apr 7, 2020

I don't see this issue in our (not yet released) container, because I changed the ElementwiseSum implementation to be vectorized, and there the logic is correct:

      for (size_t i = 0; i < inputs.size(); i += num_inputs_per_kernel) {
        if (i == 0) {
          using Kernel = VectorizedElementwiseSumFwd<DType, Req>;
          typename Kernel::ParamType params;
          params.num_inputs = std::min(num_inputs_per_kernel, inputs.size() - i);
          for (int j = 0; j < params.num_inputs; ++j) {
            params.inputs[j] = inputs[i + j].dptr<DType>();
          }
          params.outputs[0] = outputs[0].dptr<DType>();
          VectorizedKernelLauncher<DType, LType, Kernel>(size, s, params);
        } else {
          /* During subsequent launches we need to
             accumulate into the previous outputs
          */
          using Kernel = VectorizedElementwiseSumFwd<DType, kAddTo>;
          typename Kernel::ParamType params;
          params.num_inputs = std::min(num_inputs_per_kernel, inputs.size() - i);
          for (int j = 0; j < params.num_inputs; ++j) {
            params.inputs[j] = inputs[i + j].dptr<DType>();
          }
          params.outputs[0] = outputs[0].dptr<DType>();
          VectorizedKernelLauncher<DType, LType, Kernel>(size, s, params);
        }
      }

where I change the subsequent launches to use kAddTo rather than passing the output as one of the inputs.

sxjscience (Member Author) commented:

@ptrendx Thanks! I think that explains the cause.

ptrendx (Member) commented Apr 7, 2020

It does not explain why it does the right thing from nrepeat=8 onward, though. There has to be something else going on that somehow limits ElementwiseSum to 7 inputs.

zhreshold (Member) commented:

I don't think it's related to a particular op implementation; it may be something that doesn't work at all once autograd is involved. See the experiments I did: #16708 (comment)

What I did was replicate the same node N times. If N is in (1, 2, 3, 4, 8, 9, 10, ...), the loss and gradients are always GOOD; however, with (5, 6, 7), the gradients diverge at the first iteration.

ptrendx (Member) commented Apr 7, 2020

@zhreshold For any op that you check, the backward pass introduces an ElementwiseSum node that aggregates the gradients, which you do not see in the model itself. As I said, I still do not understand why 8+ is fine; I would expect everything above 4 to fail (the ElementwiseSum implementation has special cases for 2, 3, and 4, and then the buggy generic one for 5+). I will look into it further.

sxjscience (Member Author) commented:

@ptrendx After checking the source code, I think it's because MXNET_EXEC_INPLACE_GRAD_SUM_CAP is set to 8 by default. ElementwiseSum is used in the GraphExecutor to accumulate the gradient:
https://github.com/apache/incubator-mxnet/blob/79c576b8157539d365cc9e0e1e355d4ca12f7374/src/executor/graph_executor.cc#L255-L263

And inplace_sum_cap defaults to 8:
https://github.com/apache/incubator-mxnet/blob/79c576b8157539d365cc9e0e1e355d4ca12f7374/src/executor/graph_executor.cc#L228

That would explain the 5/6/7 pattern: 2-4 gradient contributions hit ElementwiseSum's specialized (correct) kernels, 5-7 hit the buggy generic path, and 8 or more exceed the cap, so the executor aggregates the gradients through a different path that avoids the bug.

ptrendx (Member) commented Apr 7, 2020

Yup, setting that env variable (undocumented, BTW ;-) ) to a higher value makes the test fail for the higher cases too.
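
(For anyone repeating that check: the variable can be overridden on the command line when running the earlier reproduction script. The value 32 below is an arbitrary example; the only requirement is that it exceeds nrepeat so that the gradient accumulation presumably stays on the ElementwiseSum path.)

MXNET_EXEC_INPLACE_GRAD_SUM_CAP=32 python grad_req_addto_bug.py --addto --nrepeat 12
MXNET_EXEC_INPLACE_GRAD_SUM_CAP=32 python grad_req_addto_bug.py --nrepeat 12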
