
Make Flax pt-flax equivalence test more aggressive #15841

Merged Mar 18, 2022 (14 commits)

Conversation

ydshieh
Collaborator

@ydshieh ydshieh commented Feb 27, 2022

What does this PR do?

Make Flax pt-flax equivalence test more aggressive. (Similar to #15839 for PT/TF).

It uses output_hidden_states=True and output_attentions=True to test all output tensors (in a recursive way).
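The recursive comparison described above can be sketched as follows (a minimal standalone version; the name `check_outputs` follows the PR's diff, but the body here is an illustrative simplification, not the actual test code):

```python
import numpy as np

def check_outputs(fx_outputs, pt_outputs, tol=1e-5):
    """Recursively compare (possibly nested) Flax and PyTorch outputs."""
    if isinstance(fx_outputs, (list, tuple)):
        # hidden_states / attentions come back as tuples of arrays:
        # compare element-wise, recursing into nested tuples.
        assert len(fx_outputs) == len(pt_outputs)
        for fx_o, pt_o in zip(fx_outputs, pt_outputs):
            check_outputs(fx_o, pt_o, tol=tol)
    else:
        fx_array = np.asarray(fx_outputs, dtype=np.float32)
        pt_array = np.asarray(pt_outputs, dtype=np.float32)
        max_diff = float(np.amax(np.abs(fx_array - pt_array)))
        assert max_diff <= tol, f"max difference {max_diff} exceeds tolerance {tol}"
```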

Also, it lowers the tolerance from 4e-2 to 1e-5. (From the experience I gained with the PT/TF test: whenever an error was > 1e-5, I always found a bug to fix.)

(A bit) surprising, but very good news: unlike PT/TF, this more aggressive test finds no PT/Flax inconsistency! (@patil-suraj must have done a great job on the Flax models :-) )

Flax: @patil-suraj @patrickvonplaten

@HuggingFaceDocBuilder

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ydshieh ydshieh force-pushed the aggressive_pt_flax_equiv_test branch from 068a659 to eefded6 Compare February 28, 2022 13:36
@ydshieh ydshieh changed the title [WIP] Make Flax pt-flax equivalence test more aggressive Make Flax pt-flax equivalence test more aggressive Feb 28, 2022
@ydshieh ydshieh marked this pull request as ready for review February 28, 2022 13:36
Contributor

@patil-suraj patil-suraj left a comment


Thanks a lot for working on this! Left a comment about the tolerance value. I'm actually surprised that all tests are passing with 1e-5. Did you test this on GPU?

@@ -160,12 +160,58 @@ def recursive_check(tuple_object, dict_object):
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

def check_outputs(self, fxo, pto, model_class, names):
Contributor


(nit) could we call this fx_outputs and pt_outputs instead of fxo and pto ?

Collaborator Author


Sure, you and Sylvain told me the same thing :)

Comment on lines 165 to 180
Args:
model_class: The class of the model currently being tested. For example, ..., etc.
Currently unused, but it could make debugging easier and faster.

names: A string, or a list of strings. These specify what fxo/pto represent in the model outputs.
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
Contributor


(nit) maybe also document fxo and pto

fxo[pt_nans] = 0

max_diff = np.amax(np.abs(fxo - pto))
self.assertLessEqual(max_diff, 1e-5)
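The NaN masking in the diff above zeroes out positions that are NaN in either framework's output before taking the max difference. A self-contained sketch of that logic (function name hypothetical; copying the inputs also sidesteps the `ValueError: assignment destination is read-only` mentioned in the commit list):

```python
import numpy as np

def max_diff_ignoring_nans(fx_outputs, pt_outputs):
    # Copy so we can write in place even if the inputs are read-only
    # (e.g. arrays obtained by converting framework tensors to numpy).
    fx = np.array(fx_outputs, dtype=np.float32, copy=True)
    pt = np.array(pt_outputs, dtype=np.float32, copy=True)
    # Zero out positions that are NaN in either output so they do not
    # contribute to the reported difference.
    nan_mask = np.isnan(fx) | np.isnan(pt)
    fx[nan_mask] = 0
    pt[nan_mask] = 0
    return float(np.amax(np.abs(fx - pt)))
```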
Contributor


I'm not sure if 1e-5 will work for all models, especially on TPU/GPU, since JAX does some approximations on TPU, so the outputs can diverge. cf #15754
What do you think @patrickvonplaten

Collaborator Author


I will check on GPU VM - currently I am doing this for PT/TF.

Contributor


Yes I think a precision of 1e-3 would be better

Comment on lines 210 to 213
# Pure convolutional models have no attention
# TODO: use a better and more general criterion
if "FlaxConvNext" not in model_class.__name__:
config.output_attentions = True
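A minimal sketch of the config toggling shown in the snippet above (the `FlaxConvNext` name check mirrors the diff; `SimpleNamespace` stands in for the real config object, and `prepare_config` is a hypothetical helper name):

```python
from types import SimpleNamespace

def prepare_config(config, model_class):
    config.output_hidden_states = True
    # Pure convolutional models have no attention layers, so only enable
    # attention outputs for other architectures (name check as in the diff).
    if "FlaxConvNext" not in model_class.__name__:
        config.output_attentions = True
    return config

# Usage with a stand-in config and a dummy model class:
class FlaxBertModel:
    pass

config = prepare_config(
    SimpleNamespace(output_hidden_states=False, output_attentions=False),
    FlaxBertModel,
)
```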
Contributor


ConvNext is not available in flax (yet!)

Collaborator Author

@ydshieh ydshieh Mar 1, 2022


(I know :), I am just copying from PT/TF. Do you want me to remove it for now?)

@@ -160,12 +160,58 @@ def recursive_check(tuple_object, dict_object):
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

def check_outputs(self, fxo, pto, model_class, names):
"""
Contributor


I don't think we need docstrings for test functions

config.output_hidden_states = True
# Pure convolutional models have no attention
# TODO: use a better and more general criterion
if "FlaxConvNext" not in model_class.__name__:
Contributor


This we can delete

@patrickvonplaten
Contributor

Also very surprised that the tests are all passing. @ydshieh - can you double-check that the tests are actually passing for most models? Some edge-case models that could be tested locally:

  • BigBird
  • Pegasus
  • GPT2

@ydshieh
Collaborator Author

ydshieh commented Mar 3, 2022

@patil-suraj @patrickvonplaten

There are a few things to fix in this PR (same for the PT/TF one) - some tests were simply skipped due to mistakes on my part.
I will let you know once I fix them.

@ydshieh ydshieh marked this pull request as draft March 3, 2022 15:40
@ydshieh ydshieh changed the title Make Flax pt-flax equivalence test more aggressive [WIP] Make Flax pt-flax equivalence test more aggressive Mar 3, 2022
@ydshieh ydshieh force-pushed the aggressive_pt_flax_equiv_test branch from eefded6 to 4603f5b Compare March 14, 2022 15:14
@ydshieh
Collaborator Author

ydshieh commented Mar 14, 2022

[Updated Info.]

  • The (more aggressive) PT/TF test is merged to master
  • I fixed some bugs for this new PT/Flax test
    • There are 10 failures
      • 6 have very large difference between PT/Flax ( > 1.8) --> I will check if I can fix them easily
      • 4 have large difference (0.01 ~ 0.05) --> Need to verify if these are expected

@ydshieh
Collaborator Author

ydshieh commented Mar 15, 2022

    * 6 have `very large` difference between PT/Flax ( > `1.8`) --> I will check if I can fix them easily
    * 4 have `large` difference (0.01 ~ 0.05) --> Need to verify if these are expected

Good news! Once #16167 and #16168 are merged, this more aggressive PT/Flax test will pass with 1e-5 on CPU. I will test with GPU later.

Sorry, but please ignore the above claim. The tests ran the Flax models on CPU because the GPU version of JAX/Flax was not installed.


Running on GPU with 1e-5 also passes! (run 3 times per model class)

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 15, 2022

The documentation is not available anymore as the PR was closed or merged.

@ydshieh ydshieh force-pushed the aggressive_pt_flax_equiv_test branch from 37b3720 to 27909ca Compare March 17, 2022 14:16
@ydshieh
Collaborator Author

ydshieh commented Mar 17, 2022

Update:

After rebasing on a more recent commit on master (5a6b3ccd28a320bcde85190b0853ade385bd4158), this test with 1e-5 works fine!

I ran this new test on GPU (inside the Docker container that we use for CI GPU testing, with jax==0.3.0). The only errors I got were:

  • Flax/Jax failed to determine best cudnn convolution algorithm for ...
    • Need to find out the cause eventually.

Think this PR is ready! @patil-suraj @patrickvonplaten

(After installing jax[cuda11_cudnn805] instead of jax[cuda11_cudnn82], the errors listed below no longer appear)


Error logs

FAILED tests/beit/test_modeling_flax_beit.py::FlaxBeitModelTest::test_equivalence_flax_to_pt - RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
FAILED tests/clip/test_modeling_flax_clip.py::FlaxCLIPVisionModelTest::test_equivalence_flax_to_pt - RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
FAILED tests/clip/test_modeling_flax_clip.py::FlaxCLIPModelTest::test_equivalence_flax_to_pt - RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
FAILED tests/vit/test_modeling_flax_vit.py::FlaxViTModelTest::test_equivalence_flax_to_pt - RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
FAILED tests/wav2vec2/test_modeling_flax_wav2vec2.py::FlaxWav2Vec2ModelTest::test_equivalence_flax_to_pt - RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:

@ydshieh ydshieh changed the title [WIP] Make Flax pt-flax equivalence test more aggressive Make Flax pt-flax equivalence test more aggressive Mar 17, 2022
@ydshieh ydshieh marked this pull request as ready for review March 17, 2022 15:44
Contributor

@patrickvonplaten patrickvonplaten left a comment


Very nice!

Contributor

@patil-suraj patil-suraj left a comment


Thanks a lot for working on this!
Could you just rebase and push again? This will make the CI green :)

@ydshieh ydshieh force-pushed the aggressive_pt_flax_equiv_test branch from 27909ca to 53605f9 Compare March 18, 2022 16:41
@ydshieh
Collaborator Author

ydshieh commented Mar 18, 2022

It's all green! I will approve my own PR too: LGTM!

@patil-suraj patil-suraj merged commit d481b64 into huggingface:master Mar 18, 2022
@ydshieh ydshieh deleted the aggressive_pt_flax_equiv_test branch March 18, 2022 17:16
FrancescoSaverioZuppichini pushed a commit that referenced this pull request Mar 24, 2022
* Make test_equivalence_pt_to_flax more aggressive

* Make test_equivalence_flax_to_pt more aggressive

* don't use to_tuple

* clean-up

* fix missing test cases + testing on GPU

* fix conversion

* fix `ValueError: assignment destination is read-only`

* Add type checking

* commit to revert later

* Fix

* fix

* fix device

* better naming

* clean-up

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>