-
Notifications
You must be signed in to change notification settings - Fork 96
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Address issues of top-k op #16670
base: main
Are you sure you want to change the base?
Address issues of top-k op #16670
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clang-Tidy
found issue(s) with the introduced code (1/1)
6e27e82
to
d0be909
Compare
Please refer to #13235 (comment) for support of |
(1, 1, 2048, 64, 32), | ||
(1, 1, 32, 32768, 32), | ||
(1, 1, 8192, 64, 32), | ||
(1, 1, 64, 64, 2, 32), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The issue mentions a k of 50. That should be tested as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes i tried k=50 but If i'm not mistaken, tt-budha (& thereby my code) only supports powers of 2.
Ill see what needs to be done to support non-powers-of-2 numbers.
One idea is : convert K to nearest power of 2 --> do LLK/compute kernel --> then reshape or slice output to desired shape
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be easiest to use the separation of ExecuteTopK and TopK and do what you suggested, except convert K to either 32 or 64.
ExecuteTopK:
- invokes TopK with k either 32 or 64
- then reshape or slice output to desired shape
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is correct that the TopK algorithm only supports powers of 2 for K. Any non-power-2 values need to be rounded UP to the nearest supported K value, and then you can truncate the output if needed. Rounding down won't work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@asandhupatlaTT I see only 32 and 64 in tests/ttnn/unit_tests/operations/test_topk.py
What happens when you try 2, 4, 8, and 16?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@bbradelTT since my kernel is. similar to tt-budha, it should work. but ill few test cases in next patch
@@ -34,9 +34,9 @@ void kernel_main() { | |||
uint32_t final_indices_cb_addr = get_write_ptr(final_indices_cb_index); | |||
|
|||
uint64_t noc_final_addr_values = | |||
get_noc_addr(noc_final_x, noc_final_y, final_values_cb_addr) + start_wt * tile_bytes_values; | |||
get_noc_addr(noc_final_x, noc_final_y, final_values_cb_addr) + start_wt * tile_bytes_values * Kt; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be good to clean up how Kt is defined as well in the program factory.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry i dont understand this comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In ttnn/cpp/ttnn/operations/reduction/topk/device/topk_program_factory.hpp Kt is defined by
uint32_t Kt = k % TILE_WIDTH == 0 ? k / TILE_WIDTH : k / TILE_WIDTH + 1;
That indicates that k could be something other than a multiple of TILE_WIDTH (e.g. 1, 2, 3, etc.). If that is not the case, it would be good to at least put a comment, and possibly change the code to just
uint32_t Kt = k / TILE_WIDTH;
cb_wait_front(input_transposed_cb_index, Wt); | ||
cb_wait_front(index_transposed_cb_index, Wt); | ||
|
||
while (idx < num_k_sequences) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be a for loop.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was sticking with tt-budha code. ok ill convert while to for loop
|
||
cb_push_back(input_transposed_cb_index, Wt); | ||
cb_push_back(index_transposed_cb_index, Wt); | ||
// print_all_tiles(input_transposed_cb_index, 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should delete commented out code unless have an explanation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thats a good place to print in-case people wanna debug this in future.
I see people adding such commented lines in other kernels (for example:
Line 46 in d0b0f9b
// DPRINT << cb_id_intermed0 << " "<< cb_id_intermed1 << " " <<intermed0_addr << " " << intermed1_addr <<" " << |
But i can remove it in next patch
release_dst(); | ||
direction = !direction; | ||
|
||
while (idx < num_k_sequences) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use for loops where possible. while loops should only be used when for loops would be too awkward.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was sticking with tt-budha code. ok ill convert while to for loop
|
||
int end_phase = (K <= 64) ? logk - 1 : 5; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you need the elvis operator, and this variable? if K == 64, isn't logk - 1 equal to 5?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was sticking with tt-budha code.
for K > 64, its capped to 5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this code be able to deal with K larger than 64? E.g. 128 or 256?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thats the goal. But Radomir said the implementation is not there yet. So 64 is max value that we can support as of now.
@rdjogoTT lemme know if i missed anything
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
K=64 was the max value we have tested in the past, however the LLKs should be able to support K>64 as well. It should only require compute kernel-level changes, but it will be complex.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rdjogoTT If K=128 was supported, would this end phase still be 5?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, since we load 2 tiles into Dest at a time the max subsequence length we can sort is 64. 128 will then require some additional steps to get to 128
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, thanks. In that case, we can leave this as is. @asandhupatlaTT please add a comment about only supporting up to 64.
2ffbb3d
to
f39f716
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clang-Tidy
found issue(s) with the introduced code (1/1)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clang-Tidy
found issue(s) with the introduced code (1/1)
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
22fe8fd
to
9d4eef9
Compare
Ticket
Link to Github Issue
Problem description
few flags & input combinations (k, largest, sorted) are not supported
What's changed
Change compute kernel to support those flags. Adding code to support k=64 from tt-metal side
Checklist