Fix longformer onnx broken export #20292

Merged
merged 8 commits into huggingface:main on Nov 22, 2022

Conversation

fxmarty
Contributor

@fxmarty fxmarty commented Nov 17, 2022

This PR fixes the ONNX export of Longformer, which was silently broken in several cases:

  • The export records padding_len > 0 as a constant equal to True. Hence, during inference in the dynamic case padding_len == 0, we would still go through the padding_len > 0 path, which contains negative indexing that makes some ONNX nodes (Gather) fail. This PR fixes the negative indices.
  • The export records hidden_states.size(1) == window_overlap * 2 as a constant equal to True, so the converted ONNX model failed when the input_ids length was strictly greater than attention_window (the case where the else path should be taken). This PR removes the hidden_states.size(1) == window_overlap * 2 path, since the other path handles this case as well. (A minimal sketch of how tracing freezes such conditions follows this list.)
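
To illustrate the problem, here is a minimal sketch (not the actual Longformer code; the module, shapes, and file name are made up for illustration) of how tracing-based torch.onnx.export freezes a Python `if` on a shape-derived value into whichever branch the example input takes:

```python
import torch


class PadIfNeeded(torch.nn.Module):
    # Mimics the padding_len > 0 control flow: pad the sequence up to a
    # multiple of 8 only when padding is actually needed.
    def forward(self, hidden_states):
        seq_len = hidden_states.shape[1]
        padding_len = (8 - seq_len % 8) % 8
        if padding_len > 0:  # evaluated once while tracing, then baked into the graph
            hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, padding_len))
        return hidden_states


# Export with an example input that needs padding: the padding_len > 0 branch
# is recorded as always taken, even for later inputs where padding_len == 0.
model = PadIfNeeded()
dummy = torch.randn(1, 6, 4)  # seq_len = 6 -> padding_len = 2
torch.onnx.export(
    model,
    (dummy,),
    "pad_if_needed.onnx",
    input_names=["hidden_states"],
    output_names=["padded"],
    dynamic_axes={"hidden_states": {1: "sequence_length"}},
)
```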

I had to run make fix-copies, which modified the LED model as well.

@michaelbenayoun @lewisbails Where should I add tests for this? Optimum?

Before submitting

  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Nov 17, 2022

The documentation is not available anymore as the PR was closed or merged.

@michaelbenayoun
Member

Yes, I would add support for Longformer in Optimum, and the tests there as well.

Member

@michaelbenayoun michaelbenayoun left a comment


LGTM (when the tests pass)

@@ -221,7 +221,10 @@ def generate_dummy_inputs(
)
import torch

# for some reason, replacing this code by inputs["global_attention_mask"] = torch.randint(2, inputs["input_ids"].shape, dtype=torch.int64)
# makes the export fail randomly
Member


It might try to attend to some forbidden positions?

Member

@lewtun lewtun left a comment


Thanks a lot for fixing this sneaky bug @fxmarty !

The modeling changes look good to me, but gently pinging @ydshieh for his opinion too.

@ydshieh ydshieh self-requested a review November 21, 2022 17:53
@@ -1294,7 +1291,7 @@ def forward(

is_index_masked = attention_mask < 0
is_index_global_attn = attention_mask > 0
is_global_attn = is_index_global_attn.flatten().any().item()
is_global_attn = torch.any(is_index_global_attn).item() # the ONNX export should record is_global_attn == True
Collaborator


I will try to review in more detail later, but a question here:

Why do we need to:

  • use torch.any instead of the tensor's any() method
  • remove flatten

Are these just a simplification rather than a necessity?

Furthermore, could you explain a bit to me (an ONNX newbie) what this means:

# Record `is_global_attn == True` to enable ONNX export

Contributor Author


Right, it's probably unnecessary. I'll fix.

The comment means that, in the normal scenario, the control flow should be traced with is_global_attn == True. This means that the example provided to torch.onnx.export should have at least one nonzero id in global_attention_mask.
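
To make that concrete, here is a hedged sketch of such an export, where the checkpoint, output file name, and opset are placeholder choices rather than what transformers.onnx actually uses; the key point is that the example global_attention_mask has at least one nonzero entry:

```python
import torch
from transformers import AutoTokenizer, LongformerModel

checkpoint = "allenai/longformer-base-4096"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = LongformerModel.from_pretrained(checkpoint)
model.config.return_dict = False  # export plain tuples instead of ModelOutput
model.eval()

inputs = tokenizer("A dummy sentence for the export.", return_tensors="pt")
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1  # at least one global token, e.g. <s>

# With a nonzero global attention position in the example, the traced graph
# records the is_global_attn == True path.
torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["attention_mask"], global_attention_mask),
    "longformer.onnx",
    input_names=["input_ids", "attention_mask", "global_attention_mask"],
    output_names=["last_hidden_state", "pooler_output"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "global_attention_mask": {0: "batch", 1: "sequence"},
    },
    opset_version=14,
)
```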

Collaborator


Thank you. I understand the comment better now.

One more question, however: if the case is_global_attn = True is recorded and used a few lines below (note this is an assignment, not ==) during export time, would we have trouble at inference time if we feed an input without any positive global attention mask?

Contributor Author


Yes, exactly. IMO the ONNX export is lacking in this sense: since we record only a single path, there may well be cases where the exported ONNX is invalid (or throws weird errors), while it is valid in other cases.

I was thinking of adding some kind of support (in Optimum exporters) for an external file accompanying the .onnx that would specify which cases the exported ONNX supports.

Collaborator

@ydshieh ydshieh left a comment


Thank you @fxmarty!

@fxmarty fxmarty force-pushed the fix-longformer-onnx-controlflow branch from 4f964e5 to 8028aea on November 22, 2022 at 14:18
@lewtun
Member

lewtun commented Nov 22, 2022

Thanks for iterating @fxmarty !

Since @ydshieh has also approved, gently pinging @sgugger for final approval :)

Collaborator

@sgugger sgugger left a comment


Thanks for the fix!

@sgugger sgugger merged commit 3d0c0ae into huggingface:main Nov 22, 2022
@adithya1111

adithya1111 commented Nov 22, 2022

Hello @sgugger @lewtun @ydshieh @fxmarty: I am stuck with the same issue. Thanks for the fix. This may be a noob question, but I wanted to check when this change will be reflected in the PyPI package. Is it in the next release? If so, do we know when that will happen?

@ydshieh
Collaborator

ydshieh commented Nov 22, 2022

@adithya1111 Sometime next week, if I remember correctly.

@fxmarty
Contributor Author

fxmarty commented Nov 22, 2022

Hello @adithya1111, yes, this will be in the next release; feel free to try the main branch in the meantime. In any case, I would advise you to be very careful with the exported ONNX model and to check that its outputs are on par with PyTorch for your target sequence length. You can edit the ONNX export code if you want to use the exported ONNX with a different sequence length, as explained in my messages above.
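
Something along these lines can serve as that sanity check; it is only a sketch, and the file name longformer.onnx, the checkpoint, and the assumed ONNX input names need to be adapted to your export:

```python
import numpy as np
import onnxruntime as ort
import torch
from transformers import AutoTokenizer, LongformerForTokenClassification

checkpoint = "allenai/longformer-base-4096"  # replace with your fine-tuned model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = LongformerForTokenClassification.from_pretrained(checkpoint).eval()

text = "A sentence roughly as long as what you plan to serve in production."
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    pt_logits = model(**inputs).logits.numpy()

# Feed the same inputs to the exported model; add "global_attention_mask"
# here if your export declares it as an input.
session = ort.InferenceSession("longformer.onnx")
onnx_inputs = {k: v.numpy() for k, v in inputs.items()}
onnx_logits = session.run(None, onnx_inputs)[0]

# A loose tolerance to start with; tighten it to what your application needs.
print("max abs diff:", np.abs(pt_logits - onnx_logits).max())
```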

For reference: huggingface/optimum#503

@adithya1111

adithya1111 commented Nov 22, 2022

Thanks a lot for the comments @fxmarty @ydshieh. Another question: I used the default parameters when training, so would that mean the global attention mask is None? I see that we are now setting global_attention_mask[:, ::2] = 1. I assume this makes every second token global. Could this lead to a discrepancy?

My original model's results are logits=tensor([[[ 0.6209, 0.0719, 0.1107, -0.5316], [ 3.0321, -0.2787, -0.6460, -2.5359], [ 2.6904, 0.1169, -0.7495, -2.8346], [ 0.6474, 0.0761, 0.1041, -0.5438]]])

And my ONNX-converted predictions are [array([[[ 0.49600145, 0.08062335, 0.12902021, -0.4010917 ], [ 3.0400352 , -0.34643874, -0.6276542 , -2.444679 ], [ 2.158992 , 0.02124629, -0.5462518 , -2.094074 ], [ 0.6290194 , 0.06919068, 0.10753635, -0.5197539 ]]], dtype=float32)]

They are close, but there are some discrepancies. Please find my model's config below:

{ "_name_or_path": "/opt/ml/input/data/model-base", "architectures": [ "LongformerForTokenClassification" ], "attention_mode": "longformer", "attention_probs_dropout_prob": 0.1, "attention_window": [ 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512 ], "bos_token_id": 0, "eos_token_id": 2, "gradient_checkpointing": false, "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "id2label": { "0": "LABEL_0", "1": "LABEL_1", "2": "LABEL_2", "3": "LABEL_3" }, "ignore_attention_mask": false, "initializer_range": 0.02, "intermediate_size": 3072, "label2id": { "LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2, "LABEL_3": 3 }, "layer_norm_eps": 1e-05, "max_position_embeddings": 4098, "model_type": "longformer", "num_attention_heads": 12, "num_hidden_layers": 12, "pad_token_id": 1, "position_embedding_type": "absolute", "sep_token_id": 2, "torch_dtype": "float32", "transformers_version": "4.9.1", "type_vocab_size": 1, "use_cache": true, "vocab_size": 50265 }

@fxmarty
Contributor Author

fxmarty commented Nov 22, 2022

@adithya1111 Could you open an issue in https://github.com/huggingface/optimum/issues with reproducible code? The ONNX export through transformers.onnx will eventually depend on optimum.exporters, so we can track the issue there.

Thanks!

@adithya1111

Created a new issue. Thanks

@adithya1111

@sgugger @fxmarty @ydshieh @lewtun: Does the latest release fix this issue?

@fxmarty
Contributor Author

fxmarty commented Dec 1, 2022

With the release, you will have no error when exporting or when running the ONNX model. However, following the discussion above (see the resolved comments), as well as #20275, huggingface/optimum#503, and huggingface/optimum#505, you can expect non-meaningful output when running the ONNX model with a sequence length significantly different from the example provided to torch.onnx.export during the conversion.

Work is in progress to add options for customizing the export further; see huggingface/optimum#522.

mpierrau pushed a commit to mpierrau/transformers that referenced this pull request Dec 15, 2022
* fix controlflow for onnx export

* fix warning

* fix the case padding_len = 0, explicit the recorded control flows

* style

* style

* fix bug

* fix copy

* nits