
[Longformer] Major Refactor #5219

Merged

Conversation

@patrickvonplaten (Contributor) commented Jun 23, 2020

Longformer Refactor

This PR does a major refactoring of Longformer. Mainly, the Roberta abstraction is removed and composition is used instead. This has the following advantages:

  • It's easier now to implement a cross_attention_layer
  • The code is more readable and the logic stays in this file only
  • A bug was corrected regarding the attention mask. @ibeltagy - maybe you can check this as well. Previously, if no attention_mask was passed, the padding function that came before super().forward() in LongformerModel was not applied; but if an attention_mask = torch.tensor([1, ..., 1]) (attend to all tokens) was passed instead, the padding function was applied, which could lead to different outputs than when no attention_mask is passed. This should not be the case: model(input_ids) and model(input_ids, attention_mask=torch.ones(input_ids.shape)) should always yield the same result (see the sketch after this list). Removing the super().forward() abstraction makes the code much cleaner here, so that an attention_mask = torch.ones(input_ids.shape) can be computed before calling the Longformer encoder. IMPORTANT: since in almost all tasks Longformer passes either a global_attention_mask or an attention_mask to LongformerModel, this bug did not really become visible before.
  • We don't have to "inject" a self-attention layer into another model anymore, which I did not like very much.
  • Unnecessary code can be removed (head_mask, previous cross-attention layer inputs that do not work yet), ...
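
A minimal sketch of the equivalence the fix guarantees (the checkpoint is the one used in the slow tests below; the example sentence and tolerance are arbitrary choices for illustration):

import torch
from transformers import LongformerModel, LongformerTokenizer

model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
model.eval()
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")

with torch.no_grad():
    # no attention_mask at all ...
    output_no_mask = model(input_ids)[0]
    # ... versus an explicit all-ones attention_mask (attend to every token)
    output_ones_mask = model(input_ids, attention_mask=torch.ones(input_ids.shape, dtype=torch.long))[0]

# after this PR both calls must produce identical hidden states
assert torch.allclose(output_no_mask, output_ones_mask, atol=1e-5)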

Additionally:

  • Variable names are made more explicit, dead code (if branches that could never be reached) was removed, and the code is simplified.
  • The forward function of the self-attention layer is broken up into multiple helper functions. The advantage is that a fair amount of memory should be saved because attention_probs goes out of scope as soon as it is no longer needed, which reduces the memory bottleneck (see the sketch after this list).
  • All Longformer models are added to the tests (@sgugger), and a couple more tests are added.
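
As an illustration of the memory point above, a toy sketch (not the actual Longformer code) of how moving the attention_probs computation into a helper lets the tensor go out of scope early:

import torch

def _compute_context(query, key, value):
    # attention_probs only lives inside this helper, so it can be freed
    # as soon as the helper returns
    attention_probs = torch.softmax(query @ key.transpose(-1, -2), dim=-1)
    return attention_probs @ value

def forward_sketch(query, key, value):
    context = _compute_context(query, key, value)
    # ... any further work here runs with attention_probs already released,
    # which lowers the peak memory compared to one monolithic forward()
    return context

# toy shapes just to exercise the sketch
q = k = v = torch.randn(1, 8, 128, 64)
out = forward_sketch(q, k, v)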

Next step is to add cross attention layers to longformer.

Review

I made sure that, apart from the fixed attention_mask = None vs. attention_mask = torch.ones(...) behavior, all outputs stay the same.
It would be great if @thomwolf @LysandreJik @sgugger @sshleifer @ibeltagy could do a quick review.

@patrickvonplaten changed the title from "[WIP - Don't merge!] Refactor longformer" to "[WIP - Don't merge] Refactor longformer" on Jun 23, 2020
codecov bot commented Jun 23, 2020

Codecov Report

Merging #5219 into master will decrease coverage by 0.81%.
The diff coverage is 92.60%.


@@            Coverage Diff             @@
##           master    #5219      +/-   ##
==========================================
- Coverage   77.85%   77.04%   -0.82%     
==========================================
  Files         138      138              
  Lines       24314    24409      +95     
==========================================
- Hits        18930    18806     -124     
- Misses       5384     5603     +219     
Impacted Files                               Coverage Δ
src/transformers/modeling_longformer.py      91.66% <92.60%> (-1.45%) ⬇️
src/transformers/modeling_tf_mobilebert.py   23.62% <0.00%> (-73.11%) ⬇️
src/transformers/modeling_tf_bert.py         73.37% <0.00%> (-25.00%) ⬇️
src/transformers/modeling_tf_utils.py        87.39% <0.00%> (-0.15%) ⬇️
src/transformers/modeling_openai.py          81.09% <0.00%> (+1.37%) ⬆️
src/transformers/modeling_tf_distilbert.py   98.76% <0.00%> (+32.51%) ⬆️
src/transformers/modeling_tf_openai.py       94.98% <0.00%> (+74.19%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@patrickvonplaten changed the title from "[WIP - Don't merge] Refactor longformer" to "Refactor longformer" on Jun 30, 2020
@@ -812,7 +812,7 @@ def test_multigpu_data_parallel_forward(self):
     # Wrap model in nn.DataParallel
     model = torch.nn.DataParallel(model)
     with torch.no_grad():
-        _ = model(**inputs_dict)
+        _ = model(**self._prepare_for_class(inputs_dict, model_class))
@patrickvonplaten (Contributor, Author) commented Jun 30, 2020:
@sgugger added the prepare function here because otherwise longformer tests were failing for multiple choice.
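
For context, a hedged sketch of why multiple-choice inputs need this extra preparation step; the shapes and the expand-based duplication are illustrative, not the exact logic of _prepare_for_class:

import torch

# multiple-choice heads expect (batch_size, num_choices, seq_len), while the
# shared test inputs are (batch_size, seq_len), so a prepare step has to add
# and expand the extra choice dimension before calling the model
batch_size, num_choices, seq_len = 4, 4, 7
input_ids = torch.randint(0, 100, (batch_size, seq_len))
mc_input_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
assert mc_input_ids.shape == (batch_size, num_choices, seq_len)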

@@ -115,6 +115,18 @@ def prepare_config_and_inputs(self):
def check_loss_output(self, result):
self.parent.assertListEqual(list(result["loss"].size()), [])

def create_and_check_attention_mask_determinism(
@patrickvonplaten (Contributor, Author) commented Jun 30, 2020:
Previously there was a bug: running the model without an attention_mask and with attention_mask = torch.tensor([1, ..., 1]) did not give the same output.

self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)

def test_longformer_model_global_attention_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@patrickvonplaten (Contributor, Author):
A test for the global attention mask was missing before.


@slow
def test_inference_no_head_long(self):
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
@patrickvonplaten (Contributor, Author):
Add one slow test that a machine with 16 GB of memory can run.

expected_loss = torch.tensor(0.0620, device=torch_device)
expected_prediction_scores_sum = torch.tensor(-6.1599e08, device=torch_device)
expected_prediction_scores_mean = torch.tensor(-3.0622, device=torch_device)
expected_loss = torch.tensor(0.0074, device=torch_device)
@patrickvonplaten (Contributor, Author):
Previously the model computed the wrong attention_mask when no attention_mask was given, so the expected values are updated here.

x = x.view(B, C, M, M + L) # B x C, M x L+M
x = x[:, :, :, :-1]
return x
total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
@patrickvonplaten (Contributor, Author):
better naming

@LysandreJik (Member):
Love the better naming

@sshleifer (Contributor):
why total_num_heads? Is there another num_heads?

@ibeltagy (Contributor):
@sshleifer, total_num_heads = num_heads * batch_size
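
To make the naming concrete, a small sketch with made-up shapes (the real values come from the batch and the model config):

import torch

batch_size, num_heads, num_chunks, window_overlap, hidden_dim = 2, 12, 8, 256, 64
# the leading dimension folds batch and heads together, hence "total_num_heads"
chunked_hidden_states = torch.randn(batch_size * num_heads, num_chunks, window_overlap, hidden_dim)

total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
assert total_num_heads == batch_size * num_heads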


if attention_mask is not None:
@patrickvonplaten (Contributor, Author):
simplify forward call and move a lot of code to helper functions

@patrickvonplaten requested a review from sgugger on June 30, 2020 08:50
@patrickvonplaten changed the title from "Refactor longformer" to "[Longformer] Major Refactor" on Jun 30, 2020
@LysandreJik (Member) left a comment:

Great changes overall, I love that the function/variable names became more explicit. The code now looks closer to the library's philosophy, which is a welcome change!


Comment on lines -280 to +331
- # TODO: make tests pass for those models
- # LongformerForSequenceClassification,
- # LongformerForQuestionAnswering,
- # LongformerForTokenClassification,
- # LongformerForMultipleChoice,
+ LongformerForSequenceClassification,
+ LongformerForQuestionAnswering,
+ LongformerForTokenClassification,
+ LongformerForMultipleChoice,
@LysandreJik (Member):

very cool diff

@sgugger (Collaborator) left a comment:

This looks great! And I love that the model gets properly tested now :-)

@sshleifer (Contributor) left a comment:

Halfway, sorry for all the nits. Feel free to ignore them! This is really cool!

(inline review threads on src/transformers/modeling_longformer.py, all resolved; the total_num_heads question is shown in the thread above)
@ibeltagy (Contributor) left a comment:

This looks great. It must have been a lot of work, thanks, @patrickvonplaten. I checked the attention_mask bug you mentioned and your fix is working well, thanks for addressing it. I also left a few comments, mostly nits, so feel free to address or ignore as you see fit.
Thanks.

(inline review threads on src/transformers/modeling_longformer.py, all resolved; the total_num_heads answer is shown in the thread above)
@patrickvonplaten (Contributor, Author) commented:

@sshleifer and @ibeltagy - thanks a lot for your comments. I cleaned up the comments and some of the function naming.

All slow and normal tests pass on GPU => good to merge.

@patrickvonplaten merged commit d697b6c into huggingface:master on Jul 1, 2020