[Longformer] fix longformer slow-down #5811
Merged
A drastic slow-down of Longformer after PR #5219 was detected by @HHousen (thanks a lot!) here: #4406 (comment).
After digging a bit into the code it can be seen that the line:
is_global_attn = any(is_index_global_attn.flatten())
is the reason for the drastic slow-down. Running the following benchmark on master:
yields the following result:
Now on this branch the results are as follows:
So this single line is responsible for a slow-down by a factor of 4.
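To illustrate why this one line is expensive, here is a minimal sketch rather than the actual Longformer code. The mask shape and the helper names slow_check and fast_check are assumptions made up for this example. Python's built-in any() walks the flattened tensor through the iterator protocol and converts each element to a bool one at a time, while Tensor.any() does the reduction in a single call.

```python
import timeit

import torch

# Hypothetical stand-in for the boolean mask in the PR; the shape is an assumption.
is_index_global_attn = torch.zeros(8, 512, dtype=torch.bool)


def slow_check(mask):
    # Built-in any() iterates over the tensor in Python, element by element.
    return any(mask.flatten())


def fast_check(mask):
    # Tensor.any() reduces the whole tensor in one call; .item() yields a Python bool.
    return mask.any().item()


# Both variants agree on the result; only the cost differs.
assert slow_check(is_index_global_attn) == fast_check(is_index_global_attn)

slow = timeit.timeit(lambda: slow_check(is_index_global_attn), number=5)
fast = timeit.timeit(lambda: fast_check(is_index_global_attn), number=5)
print(f"built-in any(): {slow:.4f}s  tensor.any(): {fast:.4f}s")
```

The absolute timings depend on the device and the tensor size, so the printed numbers are only indicative and are not the benchmark results from this PR.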
Moral of the story: never use the built-in any() on a PyTorch tensor; always use tensor.any(). I wonder if this is actually a known problem of PyTorch/Python. It might be good to check whether our code has more statements like this.
Another lesson for me is that one should always run the benchmarks before and after doing a refactoring as big as #5219.
It's very simple to run the benchmark script for a model, and it usually takes only a couple of seconds. Ideally we should have performance regression tests to automatically detect such slow-downs.
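A performance regression check along these lines could be sketched as follows. This is only an illustration under stated assumptions: best_time, is_regression, the 1.5x tolerance, and the placeholder workload are all hypothetical and are not part of the existing benchmark script.

```python
import time


def best_time(fn, repeats=5):
    """Return the fastest wall-clock time (in seconds) over `repeats` calls to fn."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return min(timings)


def is_regression(measured, baseline, tolerance=1.5):
    """True if `measured` is more than `tolerance` times slower than `baseline`."""
    return measured > baseline * tolerance


def placeholder_workload():
    # Hypothetical stand-in for a model forward pass.
    sum(i * i for i in range(10_000))


# In a real setup the baseline would be recorded on master and stored;
# here it is measured in the same process purely for illustration.
baseline_seconds = best_time(placeholder_workload)
measured_seconds = best_time(placeholder_workload)
print(f"baseline: {baseline_seconds:.6f}s  measured: {measured_seconds:.6f}s")
print("regression detected:", is_regression(measured_seconds, baseline_seconds))
```

Taking the minimum over several runs reduces noise from other processes, and the tolerance keeps small fluctuations from failing the check. A slow-down by a factor of 4 would still be flagged.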
Pinging @thomwolf @mfuntowicz @sshleifer @LysandreJik @ibeltagy - think this is quite interesting.