
[Numpy] MobileBERT SQuAD training cannot reproduce the previous results #1322

Status: Closed
Labels: bug (Something isn't working)

sxjscience opened this issue Aug 27, 2020 · 23 comments

@sxjscience (Member) commented Aug 27, 2020

@zheyuye (FYI @szhengac )

I ran the MobileBERT training on SQuAD again and the resulting log is significantly different from the one reported at https://gluon-nlp-log.s3.amazonaws.com/squad_training_log/fintune_google_uncased_mobilebert_squad_2.0/finetune_squad2.0.log

To reproduce, install the master version of GluonNLP and run the command in https://github.com/dmlc/gluon-nlp/blob/master/scripts/question_answering/commands/run_squad2_mobilebert.sh.

| Date | MXNet version | GluonNLP commit | Result | Note |
| --- | --- | --- | --- | --- |
| 8/20 | 2.0.0b20200810 | 32e87d4 | {"exact": 39.99831550576939, "f1": 43.640107609141324, "total": 11873, "HasAns_exact": 80.11133603238866, "HasAns_f1": 87.40536397492154, "HasAns_total": 5928, "NoAns_exact": 0.0, "NoAns_f1": 0.0, "NoAns_total": 5945, "best_exact": 74.10932367556642, "best_exact_thresh": -1.3667802810668945, "best_f1": 76.90326568723887, "best_f1_thresh": -1.2948681116104126, "best_ckpt": "google_uncased_mobilebert_squad2.0_20615.params"} | Zheyu's initial run |
| 8/27 | | 970318d | {"exact": 28.73747157415986, "f1": 34.911053127777336, "total": 11873, "HasAns_exact": 57.557354925775975, "HasAns_f1": 69.92222229859992, "HasAns_total": 5928, "NoAns_exact": 0.0, "NoAns_f1": 0.0, "NoAns_total": 5945, "best_exact": 51.56236839888823, "best_exact_thresh": -1.2767648696899414, "best_f1": 53.27946001481111, "best_f1_thresh": -1.2767648696899414, "best_ckpt": "google_uncased_mobilebert_squad2.0_20615.params"} | Xingjian's run (@sxjscience) |
| 9/1 | apache/mxnet#19044 | latest | {"exact": 41.27853112103091, "f1": 44.679597779846105, "total": 11873, "HasAns_exact": 82.67543859649123, "HasAns_f1": 89.4873253104104, "HasAns_total": 5928, "NoAns_exact": 0.0, "NoAns_f1": 0.0, "NoAns_total": 5945, "best_exact": 77.80678851174935, "best_exact_thresh": -1.9967657327651978, "best_f1": 80.54318992216594, "best_f1_thresh": -1.994931697845459, "best_ckpt": "google_uncased_mobilebert_squad2.0_20615.params"} | Shuai's run |
@sxjscience added the bug label on Aug 27, 2020
@szha (Member) commented Aug 27, 2020

Let's try to bisect where the problem occurred first. @zheyuye, could you share which commit of GluonNLP and which version of MXNet you used to produce the above result?

@szhengac (Member)

@sxjscience Can you share your rerun log as well? Also, how long does a single run take?

@sxjscience (Member, Author)

It takes 4 hours on a g4-12dn instance. The rerun log is here: https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/debug/finetune_squad2.0_mobilebert_20200827.log You may check the best_exact and best_f1 values.

@zheyuye (Member) commented Aug 28, 2020

If I'm not mistaken, the latest GluonNLP and MXNet 0810 were used.

@zheyuye (Member) commented Aug 28, 2020

Judging from the log, the reproduction run seems to converge too slowly. How about the other pre-trained models? ELECTRA-small won't take too long.

@sxjscience (Member, Author)

The others work fine.

@szhengac (Member) commented Aug 28, 2020

Why do the two logs have basically the same gradient norm for the last few logging points but very different losses?

@szhengac (Member)

I have tried MXNet 0810 and 0826, but neither of them can reproduce the result.

@szha (Member) commented Aug 29, 2020

Zheyu's log says the run happened on 8/20, so maybe it was an earlier GluonNLP commit that was working.

@sxjscience (Member, Author)

I think it may also be the random seed. We can see that the gradient norm is extremely large for the early iterations. I'm considering how to add proper gradient scaling to MobileBERT and will recheck the conversion script. @zheyuye Do you know which version of GluonNLP you were using when producing this log?
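
For context, one common form of gradient scaling is global-norm clipping. The following is only a rough sketch of that idea in the MXNet numpy interface, assuming we already hold the list of gradient arrays; it is an illustration, not the approach the GluonNLP training script actually uses.

```python
import mxnet as mx
mx.npx.set_np()

def clip_global_grad_norm(grad_arrays, max_norm):
    """Rescale gradients in place so their global L2 norm is at most max_norm."""
    total_norm = float(mx.np.sqrt(sum((g ** 2).sum() for g in grad_arrays)))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        for g in grad_arrays:
            g *= scale  # in-place, so parameters keep pointing at the same grad buffers
    return total_norm
```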

@sxjscience (Member, Author) commented Aug 29, 2020

The conversion script is here: https://github.com/dmlc/gluon-nlp/blob/master/scripts/conversion_toolkits/convert_mobilebert.sh. We will need to re-verify the conversion to see if there are any issues.

@zheyuye (Member) commented Aug 29, 2020

> I think it may also be the random seed. We can see that the gradient norm is extremely large for the early iterations. I'm considering how to add proper gradient scaling to MobileBERT and will recheck the conversion script. @zheyuye Do you know which version of GluonNLP you were using when producing this log?

The PR and branch related to this issue (https://github.com/ZheyuYe/gluon-nlp/commits/batch) show that it was based on d8b68c6. Random seeds can indeed be another potential factor; in general I would use 10 or 28 as the seed instead of the default value of 100.

@szhengac (Member) commented Aug 29, 2020

I have tried several versions, but none of them can reproduce the result. I have traced back to commits d8b68c6 on Aug 20 and 32e87d4 on Aug 14. For 9e268c0 on Aug 8, the code is not runnable.

@sxjscience (Member, Author)

I think the gradients for the early iterations are unreasonably large, and we may want to investigate that.

@szhengac (Member)

The problem is that no version can reproduce the result.

@szha (Member) commented Aug 30, 2020

I added a markdown table to the first post summarizing our observations. Feel free to edit it directly.

@sxjscience (Member, Author)

@szha @szhengac I think we should check the conversion and training scripts of MobileBERT again to resolve this issue. Also, part of the problem is that MXNet is not reproducible even when we specify the random seed. We may change the seed-related defaults as recommended in apache/mxnet#18987.
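
For reference, seeding a run usually means covering the Python, NumPy, and MXNet RNGs at the top of the training script. This is only a minimal sketch; whether it makes GPU runs bit-exact is exactly the open question here.

```python
import random
import numpy as np
import mxnet as mx

def set_seed(seed=100):
    """Seed the Python, NumPy, and MXNet random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    mx.random.seed(seed)  # seeds the MXNet RNGs on all devices by default
```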

@sxjscience (Member, Author)

In fact, I can confirm that the forward check introduced in

```python
if test_conversion:
    tf_contextual_embedding = tf_token_outputs_np['sequence_output']
    tf_pooled_output = tf_token_outputs_np['pooled_output']
    contextual_embedding, pooled_output = model.backbone_model(
        mx_input_ids, mx_token_types, mx_valid_length)
    assert_allclose(pooled_output.asnumpy(), tf_pooled_output, 1E-2, 1E-2)
    for i in range(batch_size):
        ele_valid_length = valid_length[i]
        assert_allclose(contextual_embedding[i, :ele_valid_length, :].asnumpy(),
                        tf_contextual_embedding[i, :ele_valid_length, :], 1E-2, 1E-2)
```

has passed. Maybe the gradient is wrong.

@sxjscience (Member, Author)

@szha @szhengac @zheyuye I noticed that the gradient of mx.np.pad seems to be wrong. This causes the problem with MobileBERT, which uses a trigram embedding implemented via pad:

```python
if self._layout == 'NT':
    word_embedding = F.np.concatenate(
        [F.np.pad(word_embedding[:, 1:], ((0, 0), (0, 1), (0, 0))),
         word_embedding,
         F.np.pad(word_embedding[:, :-1], ((0, 0), (1, 0), (0, 0)))], axis=-1)
elif self._layout == 'TN':
    word_embedding = F.np.concatenate(
        [F.np.pad(word_embedding[1:, :], ((0, 1), (0, 0), (0, 0))),
         word_embedding,
         F.np.pad(word_embedding[:-1, :], ((1, 0), (0, 0), (0, 0)))], axis=-1)
```

Minimal Reproducible Example

MXNet Implementation:

```python
import mxnet as mx
import numpy as np
mx.npx.set_np()

ctx = mx.gpu()
a = mx.np.ones((3, 3, 3), ctx=ctx)
mult = np.random.normal(0, 1, (3, 3, 3))
a.attach_grad()
with mx.autograd.record():
    b = mx.np.pad(a[:, 1:], ((0, 0), (0, 1), (0, 0))) * mx.np.array(mult, ctx=ctx)
    b = b.sum()
b.backward()
print(a.grad)
```

Output:

```
[[[0. 0. 0.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[0. 0. 0.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[0. 0. 0.]
  [1. 1. 1.]
  [1. 1. 1.]]] @gpu(0)
```

JAX Implementation:

```python
from jax import grad
import jax.numpy as jnp
import numpy as np

mult = np.random.normal(0, 1, (3, 3, 3))
a = jnp.ones((3, 3, 3))

def f(x):
    b = jnp.pad(x[:, 1:], ((0, 0), (0, 1), (0, 0))) * jnp.array(mult)
    return b.sum()

print(grad(f)(a))
```

Output:

```
[[[ 0.          0.          0.        ]
  [ 0.3545383  -0.84326786 -0.31482664]
  [ 1.0994871  -1.230104    2.8007567 ]]

 [[ 0.          0.          0.        ]
  [ 1.0447861  -0.16119051 -0.39860427]
  [-0.7756538   0.5314936   1.4601654 ]]

 [[ 0.          0.          0.        ]
  [ 0.37878916 -2.0777514   0.96676654]
  [ 0.45230922  0.3094176  -0.43687683]]]
```

@sxjscience (Member, Author)

As discussed offline with @cassinixu, fixing the pad operator on the MXNet side will take some time. Meanwhile, a simple fix is to use conv1d with padding to implement the code in:

```python
if self._layout == 'NT':
    word_embedding = F.np.concatenate(
        [F.np.pad(word_embedding[:, 1:], ((0, 0), (0, 1), (0, 0))),
         word_embedding,
         F.np.pad(word_embedding[:, :-1], ((0, 0), (1, 0), (0, 0)))], axis=-1)
elif self._layout == 'TN':
    word_embedding = F.np.concatenate(
        [F.np.pad(word_embedding[1:, :], ((0, 1), (0, 0), (0, 0))),
         word_embedding,
         F.np.pad(word_embedding[:-1, :], ((1, 0), (0, 0), (0, 0)))], axis=-1)
```

@zheyuye Would you try this approach?

@sxjscience (Member, Author)

Basically, we can use mx.npx.convolution, which has the same behavior as mx.nd.Convolution.
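
To illustrate, here is a rough sketch (not the actual GluonNLP patch) of how the trigram concatenation could be expressed as a fixed-weight 1D convolution via mx.npx.convolution, avoiding mx.np.pad and therefore its backward pass. The shapes and the NT layout below are assumptions made for the example:

```python
import mxnet as mx
import numpy as np
mx.npx.set_np()

N, T, C = 2, 5, 4
word_embedding = mx.np.random.normal(0, 1, (N, T, C))

# Constant (3C, C, 3) kernel that gathers the [next, current, previous]
# token embeddings into three stacked channel blocks.
weight = np.zeros((3 * C, C, 3), dtype='float32')
for c in range(C):
    weight[c, c, 2] = 1.0          # next token (t + 1)
    weight[C + c, c, 1] = 1.0      # current token (t)
    weight[2 * C + c, c, 0] = 1.0  # previous token (t - 1)
weight = mx.np.array(weight)

# Convolution expects NCW layout, so go (N, T, C) -> (N, C, T) and back.
x = mx.np.swapaxes(word_embedding, 1, 2)
y = mx.npx.convolution(data=x, weight=weight, kernel=(3,), pad=(1,),
                       num_filter=3 * C, no_bias=True)
trigram = mx.np.swapaxes(y, 1, 2)  # (N, T, 3C)

# Forward check against the original pad + concatenate construction
# (the forward pass of mx.np.pad is fine; only its gradient is wrong).
ref = mx.np.concatenate(
    [mx.np.pad(word_embedding[:, 1:], ((0, 0), (0, 1), (0, 0))),
     word_embedding,
     mx.np.pad(word_embedding[:, :-1], ((0, 0), (1, 0), (0, 0)))], axis=-1)
print(np.allclose(trigram.asnumpy(), ref.asnumpy(), atol=1e-5))
```

In a real layer the kernel would be kept as a non-trainable constant, so only the convolution's built-in zero padding is relied on.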

@szhengac (Member) commented Sep 1, 2020

Confirmed that apache/mxnet#19044 fixed the bug. Closing this issue.

@szhengac closed this as completed on Sep 1, 2020
@sxjscience (Member, Author)

@szhengac Would you submit a PR to update the SQuAD v1 + SQuAD v2 results of MobileBERT?
