Optimize the index_select operation for dim=0 #1113
Closed
Summary:
The index_select operation is not well optimized in PyTorch, especially the
dim=0 case. It is shown to be one of the main bottlenecks in one of the
models.
This patch optimizes the index_select operation as well as its backward
counterpart (i.e., index_add_select) for the dim=0 case.
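For reference, the semantics being optimized can be sketched in plain Python. This is only an illustration of what the dim=0 forward and backward compute, not the CUDA kernels in this patch; the function names index_select_dim0 and index_select_dim0_backward are hypothetical and chosen for this sketch.

```python
def index_select_dim0(src, indices):
    """Forward: output row i is a copy of src[indices[i]]."""
    return [list(src[idx]) for idx in indices]


def index_select_dim0_backward(grad_output, indices, num_rows):
    """Backward: add each grad_output row into the source row it was read from.

    Rows selected more than once receive the sum of their gradients.
    """
    num_cols = len(grad_output[0]) if grad_output else 0
    grad_src = [[0.0] * num_cols for _ in range(num_rows)]
    for out_row, idx in enumerate(indices):
        for col in range(num_cols):
            grad_src[idx][col] += grad_output[out_row][col]
    return grad_src


src = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
indices = [2, 0, 2]
out = index_select_dim0(src, indices)
grad = index_select_dim0_backward([[1.0, 1.0]] * 3, indices, num_rows=len(src))
```

Row 2 is selected twice, so its gradient row is the sum of two incoming rows. Duplicate indices are what make the backward pass an accumulation rather than a plain copy.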
Optimizations in index_select:
performance but easier for adjusting UNROLL_FACTOR)
Optimizations in index_add_select:
buffer
are launched but return right away because another thread block already
processes the index that they get from the sorted indices list)
same performance but could be useful for the other large D cases)
flags for informing the operation to infer unique indices and the
number of unique indices from the consecutive indices range.
this property, we are able to infer unique indices and the number of
unique indices from the consecutive indices range. In the backward
op, since we already know the unique indices and the number of
unique indices, we can skip the unique operation. The performance
improvement is twofold: (1) the host-device synchronization caused by
the resize op in unique is avoided, and (2) the additional operation
for computing the frequency of each index is lighter weight than the
unique operation.
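The consecutive-range idea can be sketched as follows. This is a minimal illustration under a stated assumption, namely that the indices cover every value of an inclusive range [first, last]; the helper name unique_from_consecutive_range and its parameters are hypothetical and are not the flags added by this patch.

```python
def unique_from_consecutive_range(indices, first, last):
    """Derive unique indices, their count, and per-index frequency.

    Assumption: every index lies in [first, last] and each value in that
    range occurs at least once, so no sort/unique pass is needed.
    """
    # The output size is known up front, so no data-dependent resize occurs.
    num_unique = last - first + 1
    unique_indices = list(range(first, last + 1))
    # Frequency counting is a single pass over the indices.
    counts = [0] * num_unique
    for idx in indices:
        counts[idx - first] += 1
    return unique_indices, num_unique, counts


indices = [5, 3, 4, 3, 5, 5]
unique_indices, num_unique, counts = unique_from_consecutive_range(indices, 3, 5)
```

Because num_unique follows directly from the range bounds, the output buffers can be sized without first inspecting the index values. That corresponds to the synchronization saving described above.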
Reviewed By: jianyuh, mjanderson09
Differential Revision: D35920450