[MMA_FROM_SMEM_IT_RES] Change how residuals are handled #393
Codecov Report
Base: 91.14% // Head: 91.23% // Increases project coverage by +0.09%

Additional details and impacted files

@@            Coverage Diff             @@
##              it1     #393      +/-   ##
==========================================
+ Coverage   91.14%   91.23%   +0.09%
==========================================
  Files          75       75
  Lines        4358     4346      -12
==========================================
- Hits         3972     3965       -7
+ Misses        386      381       -5
ffb63a7 to a5aeb22
This is awesome Daniel!
@@ -337,6 +339,11 @@ class PredicatedTileIterator<
    address_iterator_.clear_mask(enable);
  }

  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    address_iterator_.set_residual_tile(enable);
It's a pity that address_iterator_ is a private member; otherwise we could have just inherited from the base class and added this method to all the classes in this file.
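To illustrate the point about member visibility, here is a minimal C++ sketch. The class names `AddressIterator`, `TileIteratorBase`, and `TileIteratorWithResidual` are hypothetical stand-ins and are not the library's classes. It assumes the inner iterator were declared `protected`, in which case a derived class could add the forwarding method without copying the base class:

```cpp
// Hypothetical stand-in for the inner address iterator.
struct AddressIterator {
  bool residual = false;
  void set_residual_tile(bool enable) { residual = enable; }
};

// Hypothetical base class. The inner iterator is protected here (an
// assumption made for illustration); with a private member, the derived
// class below would not compile.
class TileIteratorBase {
 public:
  bool is_residual() const { return address_iterator_.residual; }

 protected:
  AddressIterator address_iterator_;
};

// The derived class only adds the forwarding method; nothing else is copied.
class TileIteratorWithResidual : public TileIteratorBase {
 public:
  void set_residual_tile(bool enable) {
    address_iterator_.set_residual_tile(enable);
  }
};
```

With a private member, the only ways to add the method are to modify the base class itself or to duplicate it, which is the situation the comment describes.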
      params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1);
    pointer_ += Shape::kStrided * tile_offset.strided();
  }
  if (!Gather) {
Same comment here: this is the only thing that we changed (in addition to set_residual_tile), but we need to copy the whole file. :-/
Changes how we handle residuals for the right-hand operand of MM from shared memory. This is used in both the fw and bw passes, and enables support for arbitrary sequence lengths.
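As a rough illustration of what a residual tile is, the following self-contained C++ sketch computes a dot product over an arbitrary length by processing full tiles without bounds checks and one final partial tile with them. The names `kTile`, `load_tile`, and `tiled_dot` are hypothetical, and the tile ordering is an assumption; this is not the kernel's actual implementation, only the general idea of toggling a bounds-checked mode for the tile that does not divide evenly.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical tile width, standing in for the kernel's tile shape.
constexpr std::size_t kTile = 4;

// Loads one tile starting at `start`. In residual mode every access is
// checked against `extent` and out-of-range slots are zero-filled; otherwise
// the tile is assumed to lie fully in range.
inline void load_tile(const std::vector<float>& src, std::size_t start,
                      std::size_t extent, bool residual,
                      float (&tile)[kTile]) {
  for (std::size_t i = 0; i < kTile; ++i) {
    const std::size_t idx = start + i;
    if (residual) {
      tile[i] = idx < extent ? src[idx] : 0.0f;
    } else {
      assert(idx < extent);
      tile[i] = src[idx];
    }
  }
}

// Dot product over any length: full tiles first, then one residual tile
// if the length is not a multiple of kTile.
inline float tiled_dot(const std::vector<float>& a,
                       const std::vector<float>& b) {
  assert(a.size() == b.size());
  const std::size_t n = a.size();
  const std::size_t full_tiles = n / kTile;
  const bool has_residual = (n % kTile) != 0;
  const std::size_t total_tiles = full_tiles + (has_residual ? 1 : 0);

  float ta[kTile];
  float tb[kTile];
  float acc = 0.0f;
  for (std::size_t t = 0; t < total_tiles; ++t) {
    const bool residual = (t == full_tiles);  // only the last, partial tile
    load_tile(a, t * kTile, n, residual, ta);
    load_tile(b, t * kTile, n, residual, tb);
    for (std::size_t i = 0; i < kTile; ++i) {
      acc += ta[i] * tb[i];
    }
  }
  return acc;
}
```

Zero-filling the out-of-range slots means the residual tile contributes nothing beyond the valid elements, so the length does not need to be a multiple of the tile size.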
[Results panels: P100/V100 bw (causal); P100/V100 bw; A100 bw; P100/V100 fw (causal); P100/V100 fw; A100 fw]
Stacked PR Chain: MMA_FROM_SMEM_IT_RES