Fix longformer onnx broken export #20292
Conversation
The documentation is not available anymore as the PR was closed or merged.
Yes, I would add the support for Longformer in …
LGTM (when the tests pass)
@@ -221,7 +221,10 @@ def generate_dummy_inputs(
    )
    import torch

    # for some reason, replacing this code by inputs["global_attention_mask"] = torch.randint(2, inputs["input_ids"].shape, dtype=torch.int64)
    # makes the export fail randomly
It might try to attend to some forbidden positions?
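For context, a minimal sketch (hypothetical, not the exact code touched by this diff) of how a deterministic dummy `global_attention_mask` can be built for the export, assuming a zeros tensor with every other position made global, as mentioned later in the thread:

```python
import torch

# Hypothetical sketch: build the dummy global_attention_mask deterministically
# instead of sampling it with torch.randint (which reportedly made the export flaky).
input_ids = torch.ones((2, 16), dtype=torch.int64)  # assumed dummy shape

global_attention_mask = torch.zeros_like(input_ids)
# Make every second token global so at least one global token exists and the
# is_global_attn == True branch is the one that gets traced.
global_attention_mask[:, ::2] = 1

inputs = {"input_ids": input_ids, "global_attention_mask": global_attention_mask}
```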
@@ -1294,7 +1291,7 @@ def forward(

  is_index_masked = attention_mask < 0
  is_index_global_attn = attention_mask > 0
- is_global_attn = is_index_global_attn.flatten().any().item()
+ is_global_attn = torch.any(is_index_global_attn).item()  # the ONNX export should record is_global_attn == True
Will try to review in more detail later, but a question here:

Why do we need to:
- use `torch.any` instead of the tensor's `any()` method
- remove `flatten`

Are these just simplifications rather than necessities?
Furthermore, could you explain a bit to me (ONNX newbie) what this means:
# Record `is_global_attn == True` to enable ONNX export
Right, it's probably unnecessary. I'll fix.

The comment means that in a normal-case scenario, the control flow should be resolved with `is_global_attn == True`. This means that the example provided to `torch.onnx.export` should have at least one id of `global_attention_mask` non-zero.
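To make that concrete, a hedged sketch of what such an export call could look like; the checkpoint name, opset and dynamic axes are illustrative placeholders, not the values used in this PR:

```python
import torch
from transformers import LongformerModel

# Illustrative export sketch: the dummy global_attention_mask must contain at least
# one non-zero entry so that the traced graph records is_global_attn == True.
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
model.config.return_dict = False  # return tuples, simpler for the tracer
model.eval()

input_ids = torch.ones((1, 32), dtype=torch.int64)
attention_mask = torch.ones_like(input_ids)
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, 0] = 1  # first token attends globally

torch.onnx.export(
    model,
    (input_ids, attention_mask, global_attention_mask),
    "longformer.onnx",
    input_names=["input_ids", "attention_mask", "global_attention_mask"],
    output_names=["last_hidden_state", "pooler_output"],
    dynamic_axes={name: {0: "batch", 1: "sequence"}
                  for name in ["input_ids", "attention_mask", "global_attention_mask"]},
    opset_version=14,
)
```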
Thank you. I understand the comment better now.

One more question however: if the case `is_global_attn = True` is recorded and used a few lines below (note this is an assignment, not `==`) during export time, would we have trouble during inference time if we feed an input without any positive global attention mask?
Yes, exactly. IMO the ONNX export is lacking in this sense: since we record only a single path, there may well be cases where the exported ONNX is invalid / will throw weird errors, while it is valid in other cases.

I was thinking of adding some kind of support (in Optimum exporters) for an external file alongside the `.onnx` that would specify which cases are supported by the exported ONNX.
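A small self-contained illustration (a toy module, not Longformer itself) of how a data-dependent Python branch gets frozen into the traced graph:

```python
import torch

class ToyBranch(torch.nn.Module):
    def forward(self, x):
        # .item() converts the tensor to a Python bool at trace time, so only the
        # branch taken with the example input is recorded in the exported graph.
        if torch.any(x > 0).item():
            return x * 2
        return x - 1

model = ToyBranch()
example = torch.tensor([0.0, 1.0])  # traces the "x * 2" branch
torch.onnx.export(model, (example,), "toy_branch.onnx",
                  input_names=["x"], output_names=["y"])

# At inference time, an all-negative input still goes through the recorded
# "x * 2" branch in the ONNX graph, unlike in eager PyTorch.
```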
Thank you @fxmarty!
Thanks for the fix!
@adithya1111 Somewhere in the next week, if I remember correctly.
Hello @adithya1111, yes this will be in the next release, feel free to try the main branch meanwhile. In any case, I would advise you to be very careful with the exported ONNX model, and to check that the outputs are on par with PyTorch for your target sequence length. You can possibly edit the ONNX export code if you want to use the exported ONNX with a different sequence length, as explained in my messages above. For reference: huggingface/optimum#503
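For example, one way to sanity-check the exported model against PyTorch for the sequence length you care about could look roughly like this (a sketch assuming an already exported longformer.onnx with these input names, and onnxruntime installed):

```python
import numpy as np
import onnxruntime as ort
import torch
from transformers import LongformerModel

model = LongformerModel.from_pretrained("allenai/longformer-base-4096").eval()
session = ort.InferenceSession("longformer.onnx", providers=["CPUExecutionProvider"])

seq_len = 1024  # the sequence length you actually target
input_ids = torch.randint(0, model.config.vocab_size, (1, seq_len), dtype=torch.int64)
attention_mask = torch.ones_like(input_ids)
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, 0] = 1  # same global-attention pattern as at export time

with torch.no_grad():
    pt_out = model(input_ids, attention_mask=attention_mask,
                   global_attention_mask=global_attention_mask).last_hidden_state

onnx_out = session.run(None, {
    "input_ids": input_ids.numpy(),
    "attention_mask": attention_mask.numpy(),
    "global_attention_mask": global_attention_mask.numpy(),
})[0]

# If this is not close to zero for your target length, the export is suspect.
print("max abs diff:", np.abs(pt_out.numpy() - onnx_out).max())
```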
Thanks a lot for the comments @fxmarty @ydshieh. Another question: I used the default parameters when training, so would that mean the global attention mask is None? I see that we are now setting `global_attention_mask[:, ::2] = 1`; I assume this makes every second token global. Could this lead to a discrepancy? My original model's results are … and my ONNX-converted predictions are … They are close, but there are some discrepancies. PFB my config file for my model.
@adithya1111 Could you open an issue in https://github.com/huggingface/optimum/issues with reproducible code? ONNX export through … Thanks!
Created a new issue. Thanks
With the release you will get no error at export or when running the ONNX model. However, following the discussion above (see the closed comments), and as well in #20275, huggingface/optimum#503, huggingface/optimum#505, you can expect non-meaningful output when running the ONNX model with a significantly different sequence length than the example provided to `torch.onnx.export`. It is WIP to add options to customize the export further; refer to huggingface/optimum#522
* fix controlflow for onnx export
* fix warning
* fix the case padding_len = 0, explicit the recorded control flows
* style
* style
* fix bug
* fix copy
* nits
This PR fixes the ONNX export of longformer, which was silently broken for several cases:

* The export recorded `padding_len > 0` as a constant equal to `True`, hence during inference in the dynamic case `padding_len == 0` we would still go through the `padding_len > 0` path, which would then contain negative indexing making some ONNX nodes fail (gather). This PR fixes the negative indexes (a small illustrative sketch follows below).
* The export recorded `hidden_states.size(1) == window_overlap * 2` as a constant equal to `True`, hence using the converted ONNX model was failing when the `input_ids` length was strictly greater than `attention_window` (the case where the `else` path should be taken). This PR removes the `hidden_states.size(1) == window_overlap * 2` path, since the other path can handle this case as well.

Had to run `make fix-copies`, which modified the LED model as well.

@michaelbenayoun @lewisbails Where should I add tests for this? Optimum?
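As referenced in the first bullet above, a tiny hypothetical sketch (plain PyTorch with simplified numbers, not the actual modeling code) of why a traced `padding_len > 0` branch misbehaves when `padding_len == 0` at inference:

```python
import torch

# Longformer pads the sequence to a multiple of attention_window; roughly:
#   padding_len = (attention_window - seq_len % attention_window) % attention_window
attention_window = 8
hidden_states = torch.randn(1, 16, 4)  # seq_len = 16 -> padding_len == 0
seq_len = hidden_states.shape[1]
padding_len = (attention_window - seq_len % attention_window) % attention_window

# Eager PyTorch simply skips an `if padding_len > 0:` block here, but a traced graph
# keeps whichever branch the export example took. A slice written with a negative
# index inside that branch then misbehaves when padding_len == 0, because
# x[:, :-0] is x[:, :0], i.e. an empty tensor:
safe = hidden_states[:, : seq_len - padding_len]   # torch.Size([1, 16, 4])
broken = hidden_states[:, :-padding_len]           # torch.Size([1, 0, 4]) when padding_len == 0
print(safe.shape, broken.shape)
```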