[Longformer] Major Refactor #5219
Conversation
Codecov Report
@@ Coverage Diff @@
## master #5219 +/- ##
==========================================
- Coverage 77.85% 77.04% -0.82%
==========================================
Files 138 138
Lines 24314 24409 +95
==========================================
- Hits 18930 18806 -124
- Misses 5384 5603 +219
Continue to review full report at Codecov.
Branch force-pushed from 6e200ea to 8185265, then from 8185265 to 149b057.
@@ -812,7 +812,7 @@ def test_multigpu_data_parallel_forward(self):
            # Wrap model in nn.DataParallel
            model = torch.nn.DataParallel(model)
            with torch.no_grad():
-               _ = model(**inputs_dict)
+               _ = model(**self._prepare_for_class(inputs_dict, model_class))
@sgugger added the prepare function here because otherwise longformer tests were failing for multiple choice.
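For context: multiple-choice heads expect an extra num_choices dimension on their inputs, which is why a per-model-class preparation step is needed before the DataParallel forward pass. A minimal sketch of that idea (the helper name and the default num_choices below are illustrative, not the repo's actual _prepare_for_class implementation):

```python
import torch

def prepare_multiple_choice_inputs(inputs_dict, num_choices=4):
    # Multiple-choice heads expect [batch_size, num_choices, seq_len] tensors, so
    # repeat every 2D [batch_size, seq_len] tensor along a new choices dimension.
    return {
        name: tensor.unsqueeze(1).expand(-1, num_choices, -1).contiguous() if tensor.dim() == 2 else tensor
        for name, tensor in inputs_dict.items()
    }
```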
@@ -115,6 +115,18 @@ def prepare_config_and_inputs(self):
    def check_loss_output(self, result):
        self.parent.assertListEqual(list(result["loss"].size()), [])

+   def create_and_check_attention_mask_determinism(
There was a bug previously: running the model without an attention_mask and running it with attention_mask = torch.tensor([1, ..., 1]) did not give the same output.
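A minimal sketch of the determinism check being described (the pretrained checkpoint and toy input below are illustrative; the actual test uses the model tester's own config and inputs):

```python
import torch
from transformers import LongformerModel

model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
model.eval()
input_ids = torch.tensor([[0] + [10] * 30 + [2]])  # toy sequence: <s> ... </s>

with torch.no_grad():
    output_no_mask = model(input_ids)[0]
    output_ones_mask = model(input_ids, attention_mask=torch.ones_like(input_ids))[0]

# After the fix, an all-ones mask must behave exactly like passing no mask at all.
assert torch.allclose(output_no_mask, output_ones_mask, atol=1e-4)
```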
        self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)

    def test_longformer_model_global_attention_mask(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
A test for global attention mask was missing before
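A sketch of what such a test can look like with a small, randomly initialized model (all sizes below are illustrative, not the values used by the repo's model tester):

```python
import torch
from transformers import LongformerConfig, LongformerModel

config = LongformerConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    attention_window=4,          # small, even attention window for a toy model
    max_position_embeddings=64,
)
model = LongformerModel(config)
model.eval()

input_ids = torch.randint(3, 100, (1, 16))       # avoid special token ids 0/1/2
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, 0] = 1                  # give the first token global attention

with torch.no_grad():
    sequence_output = model(input_ids, global_attention_mask=global_attention_mask)[0]
```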
    @slow
    def test_inference_no_head_long(self):
        model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
Add one slow test that a machine with 16GB of memory can run.
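A sketch of that kind of check: run a full 4096-token sequence through the pretrained base checkpoint and verify the output shape (the toy token values and the plain shape assert below are illustrative, not the repo's exact expected tensors):

```python
import torch
from transformers import LongformerModel

model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
model.eval()

input_ids = torch.tensor([[0] + [10] * 4094 + [2]])  # one full-length 4096-token sequence
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, 0] = 1                      # global attention on the first token

with torch.no_grad():
    last_hidden_state = model(input_ids, global_attention_mask=global_attention_mask)[0]

assert last_hidden_state.shape == (1, 4096, 768)
```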
        expected_loss = torch.tensor(0.0620, device=torch_device)
        expected_prediction_scores_sum = torch.tensor(-6.1599e08, device=torch_device)
        expected_prediction_scores_mean = torch.tensor(-3.0622, device=torch_device)
        expected_loss = torch.tensor(0.0074, device=torch_device)
Previously the model calculated the wrong attention_mask when no attention_mask was given, so the expected values are updated here.
-       x = x.view(B, C, M, M + L)  # B x C, M x L+M
-       x = x[:, :, :, :-1]
-       return x
+       total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
better naming
Love the better naming
why total_num_heads? Is there another num_heads?
@sshleifer, total_num_heads = num_heads * batch_size
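A small illustration of that naming (the concrete sizes are made up): the chunked hidden states fuse the batch and head dimensions, so the leading dimension counts heads across the whole batch.

```python
import torch

batch_size, num_heads, num_chunks, window_overlap, head_dim = 2, 12, 8, 256, 64

# the attention code works on tensors whose leading dim fuses batch and heads
chunked_hidden_states = torch.randn(batch_size * num_heads, num_chunks, window_overlap, head_dim)
total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()

assert total_num_heads == batch_size * num_heads  # "total" = heads across the whole batch
```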
        if attention_mask is not None:
simplify forward call and move a lot of code to helper functions
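One example of the kind of helper this enables (a sketch of the idea, not a verbatim copy of the PR's code): folding the local attention_mask and the global_attention_mask into a single mask where 0 means padding, 1 means local attention, and 2 means global attention.

```python
import torch

def merge_to_attention_mask(attention_mask: torch.Tensor, global_attention_mask: torch.Tensor) -> torch.Tensor:
    # padding stays 0, locally attended tokens stay 1, globally attended tokens become 2
    return attention_mask * (global_attention_mask + 1)

attention_mask = torch.tensor([[1, 1, 1, 0]])
global_attention_mask = torch.tensor([[1, 0, 0, 0]])
print(merge_to_attention_mask(attention_mask, global_attention_mask))  # tensor([[2, 1, 1, 0]])
```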
Great changes overall, I love that the functions/variables names became more explicit. The code looks overall closer to the library's philosophy, which is a welcome change!
-       # TODO: make tests pass for those models
-       # LongformerForSequenceClassification,
-       # LongformerForQuestionAnswering,
-       # LongformerForTokenClassification,
-       # LongformerForMultipleChoice,
+       LongformerForSequenceClassification,
+       LongformerForQuestionAnswering,
+       LongformerForTokenClassification,
+       LongformerForMultipleChoice,
very cool diff
This looks great! And I love that the model gets properly tested now :-)
Halfway, sorry for all the nits. Feel free to ignore them! This is really cool!
This looks great. It must have been a lot of work, thanks @patrickvonplaten. I checked the attention_mask bug you mentioned and your fix is working well, thanks for addressing it. I also left a few comments, mostly nits, so feel free to address or ignore as you see fit. Thanks.
@sshleifer and @ibeltagy - thanks a lot for your comments -> cleaned up the comments and some function naming. All slow and normal tests pass on GPU => good to merge.
Longformer Refactor
This PR does a major refactoring of Longformer. Mainly, the Roberta abstraction is removed and composition is chosen instead. This has the following advantages:

- It becomes easily possible to add a cross_attention_layer to Longformer later on.
- A bug is fixed. Previously, when no attention_mask was inserted, the padding function that runs before super().forward() in LongformerModel was not used; but if instead an attention_mask = torch.tensor([1, ..., 1]) (attend to all tokens) was passed, the padding function was applied and could lead to different outputs than when no attention_mask is passed. This should not be the case: model(input_ids) and model(input_ids, attention_mask=torch.ones(input_ids.shape)) should always yield the same result. Removing the super().forward() abstraction makes the code much cleaner here, so that an attention_mask = torch.ones(input_ids.shape) can be calculated before calling the longformer encoder. IMPORTANT: since in almost all tasks longformer somehow passes either a global_attention_mask or an attention_mask to LongformerModel, this bug did not really become visible before.
- It is no longer necessary to plug the Longformer self-attention layer into another model, which I did not like very much.

Additionally, attention_probs now go out of scope after they are not used anymore, so the memory bottleneck should be reduced (see the sketch below).

Next step is to add cross attention layers to longformer.
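A sketch of the memory point above (the toy attention function is illustrative, not the Longformer code): when the attention probabilities are computed inside a helper, they can be freed as soon as the helper returns instead of staying alive for the whole forward pass.

```python
import torch

def attend(query, key, value):
    attention_scores = query @ key.transpose(-1, -2) / key.size(-1) ** 0.5
    attention_probs = torch.softmax(attention_scores, dim=-1)  # freed once this helper returns
    return attention_probs @ value

query = key = value = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
context = attend(query, key, value)               # attention_probs is no longer referenced here
```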
Review
I made sure that, besides the bug with attention_mask = None vs. attention_mask = torch.ones(...), all outputs stay the same. Would be great if @thomwolf @LysandreJik @sgugger @sshleifer @ibeltagy can do a quick review.