Fix missing output_attentions in PT/Flax equivalence test #16271
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
LGTM, thanks a lot!
tests/test_modeling_flax_common.py
Outdated
@@ -178,6 +179,12 @@ def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
        Currently unused, but in the future, we could use this information to make the error message clearer
        by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
        """
        # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
        # an effort was done to return `attention_probs` (yet to be verified).
        if type(names) == str and names.startswith("attentions"):
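The hunk above is cut off right after the new if, so the body of the special case is not visible. As a rough illustration of the underlying idea only (drop the None attention outputs before the elementwise PT/Flax comparison), not the code that was actually merged, here is a self-contained sketch; the helper name is made up:

    # Hypothetical sketch, not the merged code: FlaxBigBird's block-sparse attention
    # returns attention_probs = None, so before comparing attentions elementwise we
    # keep only the positions where the Flax side actually produced a tensor.
    def filter_none_attentions(fx_attentions, pt_attentions):
        pairs = [(fx, pt) for fx, pt in zip(fx_attentions, pt_attentions) if fx is not None]
        if not pairs:
            return (), ()
        fx_kept, pt_kept = zip(*pairs)
        return fx_kept, pt_kept

    # Example: the second Flax attention map is None, so only the first pair survives.
    fx = ("fx_layer_0", None)
    pt = ("pt_layer_0", "pt_layer_1")
    print(filter_none_attentions(fx, pt))  # (('fx_layer_0',), ('pt_layer_0',))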
This is a bit too hacky for me here. Can't we just overwrite the test in test_modeling_flax_big_bird.py?
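Sketched below is what that suggestion would look like: an override in the model-specific test file, so the common tester stays free of BigBird-specific branches. The test class and import do exist in the test suite, but the body here is illustrative only, not code that was merged:

    import unittest

    from .test_modeling_flax_common import FlaxModelTesterMixin  # path as used inside the tests package

    # Hypothetical sketch of the suggestion: handle the BigBird peculiarity
    # (block-sparse attention returns attention_probs = None) locally.
    class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
        def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
            if isinstance(names, str) and names.startswith("attentions"):
                return  # skip the attention comparison for this model only
            super().check_outputs(fx_outputs, pt_outputs, model_class, names)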
tests/test_modeling_flax_common.py
Outdated
@@ -274,7 +281,8 @@ def test_equivalence_flax_to_pt(self):

        # Output all for aggressive testing
        config.output_hidden_states = True
        # Pure convolutional models have no attention
        if self.has_attentions:
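The hunk stops at the if, so the guarded body is not shown; presumably it simply turns on attention outputs, roughly like the sketch below (a guess at the continuation, not a verbatim quote of the merged diff):

    # Likely continuation of the hunk above (sketch): only request attention outputs
    # for models that actually have attention layers.
    config.output_hidden_states = True
    # Pure convolutional models have no attention
    if self.has_attentions:
        config.output_attentions = True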
Don't like this too much here either. Can't we check if there is an output_attentions in the signature of the forward function and, if that's the case, set config.output_attentions=True? This way we have one dependency less.
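A minimal sketch of that alternative, assuming the check is done with inspect on the model class's __call__ signature (this is not what was merged; the has_attentions attribute was kept instead, as discussed below):

    import inspect

    # Sketch of the signature-based alternative: only request attention outputs when
    # the model's call signature actually accepts an output_attentions argument.
    def accepts_output_attentions(model_class):
        return "output_attentions" in inspect.signature(model_class.__call__).parameters

    # Inside the test this would then replace the has_attentions guard, e.g.:
    #     if accepts_output_attentions(model_class):
    #         config.output_attentions = True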
This has_attentions attribute was introduced in ModelTesterMixin (#15909) (and then in TFModelTesterMixin by me, #16259).
I think it would be good to have the same approach for testing across the 3 frameworks. Let me know if you still prefer the other approach(es).
cc @NielsRogge @sgugger for further comments if any.
Yes, let's use existing attributes and make the three testers consistent with each other.
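The agreed pattern, sketched: the mixin carries a has_attentions flag defaulting to True, attention-related checks are guarded by it, and the tester of a model without attentions just flips the flag. The mixin name is real; the helper method and the model-specific class below are purely illustrative:

    # Simplified sketch of the shared pattern across the PT/TF/Flax tester mixins.
    class FlaxModelTesterMixin:
        has_attentions = True  # default: most models return attention weights

        def _configure_outputs(self, config):  # illustrative helper, not a real method
            config.output_hidden_states = True
            if self.has_attentions:  # guard attention-related checks with this flag
                config.output_attentions = True

    # Hypothetical tester for a purely convolutional model: nothing to compare for attentions.
    class FlaxSomeConvModelTest(FlaxModelTesterMixin):
        has_attentions = False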
Co-authored-by: Suraj Patil <surajp815@gmail.com>
@@ -314,6 +315,7 @@ def test_equivalence_flax_to_pt(self):

        # send pytorch model to the correct device
        pt_model_loaded.to(torch_device)
        pt_model_loaded.eval()
don't forget to set to eval for re-loaded pt model
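For context, the reason eval() matters here: it puts the reloaded PyTorch model into inference mode (dropout and other train-time-only behaviour is disabled), so its outputs are deterministic and the elementwise comparison against the Flax model at 1e-5 stays meaningful.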
Think this (quite small) PR is ready. Nothing particular but adding the missing output_attentions. Will merge it today.
@@ -168,6 +169,7 @@ def recursive_check(tuple_object, dict_object):
        dict_inputs = self._prepare_for_class(inputs_dict, model_class)
        check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

    # (Copied from tests.test_modeling_common.ModelTesterMixin.check_outputs)
The intention is only to add this information, not meant to work with the current version of make fix-copies.
@sgugger Are you OK with this comment? Otherwise I can just remove it.
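(Background, roughly speaking: make fix-copies keeps code marked with a "# Copied from ..." comment in sync with its source; the parenthesized form used here deliberately doesn't match that marker, so it remains informational only.)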
What does this PR do?
In a previous PR, #15841, output_attentions was not set (I accidentally removed the whole block containing it). This PR sets output_attentions to make the test more thorough. The test still runs successfully with 1e-5 on both CPU/GPU. However, see the 2nd point in the remarks below.
It also adds a has_attentions attribute to FlaxModelTesterMixin (as done in PyTorch's ModelTesterMixin).
Remarks:
- has_attentions is used in some existing methods (to make sure the attentions are only tested if has_attentions is True), see [Tests] Add attentions_option to ModelTesterMixin #15909; here it guards test_equivalence_pt_to_flax and test_equivalence_flax_to_pt.
- There are remaining issues for FlaxGPTJ and FlaxXGLM, which will fail with 1e-5