Make Flax pt-flax equivalence test more aggressive #15841
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Force-pushed from 068a659 to eefded6.
Thanks a lot for working on this! Left a comment about the tolerance value. I'm actually surprised that all tests are passing with `1e-5`. Did you test this on GPU?
tests/test_modeling_flax_common.py (Outdated)
```diff
@@ -160,12 +160,58 @@ def recursive_check(tuple_object, dict_object):
         dict_inputs = self._prepare_for_class(inputs_dict, model_class)
         check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

     def check_outputs(self, fxo, pto, model_class, names):
```
(nit) could we call this `fx_outputs` and `pt_outputs` instead of `fxo` and `pto`?
Sure, you and Sylvain told me the same thing :)
tests/test_modeling_flax_common.py (Outdated)
```diff
         Args:
             model_class: The class of the model that is currently testing. For example, ..., etc.
                 Currently unused, but it could make debugging easier and faster.

             names: A string, or a list of strings. These specify what fxo/pto represent in the model outputs.
                 Currently unused, but in the future, we could use this information to make the error message clearer
                 by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
         """
```
(nit) maybe also document `fxo` and `pto`
```diff
             fxo[pt_nans] = 0

             max_diff = np.amax(np.abs(fxo - pto))
             self.assertLessEqual(max_diff, 1e-5)
```
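For readers skimming the thread, the check quoted above boils down to the following (a minimal standalone sketch with assumed variable names, not the exact test code): positions that are NaN in the PyTorch output are zeroed out in both arrays before taking the maximum absolute difference, so matching NaNs do not fail the comparison.

```python
import numpy as np

# Hypothetical stand-ins for a Flax output tensor and a PyTorch output tensor
fx_out = np.array([0.1, np.nan, 0.3], dtype=np.float64)
pt_out = np.array([0.1, np.nan, 0.3000000001], dtype=np.float64)

# Zero out the positions that are NaN in the PyTorch output, in both arrays,
# so NaNs at matching positions do not poison the max-diff comparison
pt_nans = np.isnan(pt_out)
fx_out[pt_nans] = 0
pt_out[pt_nans] = 0

max_diff = np.amax(np.abs(fx_out - pt_out))
```

With these toy values the remaining difference is around `1e-10`, comfortably below the `1e-5` tolerance under discussion.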
I'm not sure if `1e-5` will work for all models, especially on TPU/GPU, since JAX does some approximations on TPU, so the output can diverge. cf. #15754. What do you think @patrickvonplaten?
I will check on a GPU VM - currently I am doing this for PT/TF.
Yes, I think a precision of `1e-3` would be better.
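One way to accommodate this (a purely hypothetical helper, not part of the PR) is to make the tolerance depend on the device, since JAX may use lower-precision approximations on accelerators (e.g. TF32/bf16 matmuls), while CPU results can be held to a stricter bound:

```python
def pick_tolerance(device: str) -> float:
    # Hypothetical helper: looser tolerance on GPU/TPU, where JAX may use
    # lower-precision approximations; strict tolerance on CPU
    return 1e-3 if device in ("gpu", "tpu") else 1e-5
```

This keeps the aggressive `1e-5` bound where it is reliable while avoiding flaky failures on accelerators.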
tests/test_modeling_flax_common.py (Outdated)
```diff
             # Pure convolutional models have no attention
             # TODO: use a better and general criteria
             if "FlaxConvNext" not in model_class.__name__:
                 config.output_attentions = True
```
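As a side note on the TODO above, a slightly more general criterion than substring matching on the class name (a hypothetical sketch only; `ATTENTION_FREE_MODELS` is an invented registry, not a real transformers constant) could be an explicit set of attention-free model classes:

```python
# Hypothetical registry of pure convolutional (attention-free) model classes
ATTENTION_FREE_MODELS = {
    "FlaxConvNextModel",
    "FlaxConvNextForImageClassification",
}

def has_attention_outputs(model_class_name: str) -> bool:
    # An explicit allow-list is easier to audit and extend than a
    # substring check against the class name
    return model_class_name not in ATTENTION_FREE_MODELS
```

An explicit registry makes it obvious which models are intentionally excluded when new attention-free architectures are added.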
`ConvNext` is not available in Flax (yet!)
(I know :), I am just copying from PT/TF. Do you want me to remove it for now?)
```diff
@@ -160,12 +160,58 @@ def recursive_check(tuple_object, dict_object):
         dict_inputs = self._prepare_for_class(inputs_dict, model_class)
         check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

     def check_outputs(self, fxo, pto, model_class, names):
         """
```
I don't think we need docstrings for test functions
tests/test_modeling_flax_common.py (Outdated)
```diff
             config.output_hidden_states = True
             # Pure convolutional models have no attention
             # TODO: use a better and general criteria
             if "FlaxConvNext" not in model_class.__name__:
```
This we can delete
Also very surprised that the tests are all passing. @ydshieh - can you double-check that the tests are actually passing for most models? Some edge-case models that could be tested locally:

@patil-suraj @patrickvonplaten There are a few things to fix in this PR (same for the PT/TF one) - some tests were just being skipped due to mistakes on my part.
Force-pushed from eefded6 to 4603f5b.
[Updated info] Sorry, but please ignore the above claim. The tests ran the Flax models on CPU, because the GPU version of JAX/Flax was not installed. Running on
The documentation is not available anymore as the PR was closed or merged.
Force-pushed from 37b3720 to 27909ca.
Update: After rebasing on a more recent commit on master, I ran this new test on GPU (inside the Docker container that we use for CI GPU testing).

I think this PR is ready! @patil-suraj @patrickvonplaten

Error logs
Very nice!
Thanks a lot for working on this! Could you just rebase and push again? This will make the CI green :)
Force-pushed from 27909ca to 53605f9.
It's all green! I will approve my own PR too: LGTM!
* Make test_equivalence_pt_to_flax more aggressive
* Make test_equivalence_flax_to_pt more aggressive
* don't use to_tuple
* clean-up
* fix missing test cases + testing on GPU
* fix conversion
* fix `ValueError: assignment destination is read-only`
* Add type checking
* commit to revert later
* Fix
* fix
* fix device
* better naming
* clean-up

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
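The `ValueError: assignment destination is read-only` entry in the commit list above refers to a common NumPy pitfall: an array backed by an external (e.g. JAX) buffer can be non-writable, so masking NaNs in place fails until a copy is taken. A minimal sketch in plain NumPy, simulating the read-only buffer (this is an illustration of the error class, not the PR's actual fix):

```python
import numpy as np

fx_out = np.zeros((2, 3))
fx_out.flags.writeable = False  # simulate an array backed by a read-only buffer

try:
    fx_out[0, 0] = 1.0  # raises ValueError: assignment destination is read-only
except ValueError:
    fx_out = np.array(fx_out)  # take a writable copy before mutating

fx_out[0, 0] = 1.0  # now succeeds on the writable copy
```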
What does this PR do?

Make the Flax pt-flax equivalence test more aggressive. (Similar to #15839 for PT/TF.)

It uses `output_hidden_states=True` and `output_attentions=True` to test all output tensors (in a recursive way). Also, it lowers the tolerance from `4e-2` to `1e-5`. (From the experience I gained in the PT/TF test, whenever an error was above `1e-5`, I always found a bug to fix.)

(A bit) surprisingly, but very good news: unlike PT/TF, there is no PT/Flax inconsistency found by this more aggressive test! (@patil-suraj must have done a great job on the Flax models :-) )
Flax: @patil-suraj @patrickvonplaten