Address issues of top-k op #16670

Draft
wants to merge 16 commits into
base: main
Conversation

Contributor

@asandhupatlaTT asandhupatlaTT commented Jan 13, 2025

Ticket

Link to Github Issue

Problem description

A few flag and input combinations (k, largest, sorted) are not supported.

What's changed

Change the compute kernel to support those flags, and add code to support k=64 on the tt-metal side.

Checklist

Contributor

@github-actions github-actions bot left a comment

⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

ttnn/cpp/ttnn/operations/reduction/topk/topk.cpp (outdated, resolved)
@asandhupatlaTT
Contributor Author

Please refer to #13235 (comment) for support of sorted=False

ttnn/cpp/ttnn/operations/reduction/topk/topk.cpp (outdated, resolved)
(1, 1, 2048, 64, 32),
(1, 1, 32, 32768, 32),
(1, 1, 8192, 64, 32),
(1, 1, 64, 64, 2, 32),
Contributor

The issue mentions a k of 50. That should be tested as well.

Contributor Author

Yes, I tried k=50, but if I'm not mistaken, tt-buda (and thereby my code) only supports powers of 2.
I'll see what needs to be done to support non-power-of-2 values.

One idea: convert K to the nearest power of 2 --> run the LLK/compute kernel --> then reshape or slice the output to the desired shape.

Contributor

It would be easiest to use the separation of ExecuteTopK and TopK and do what you suggested, except convert K to either 32 or 64.

ExecuteTopK:

  • invokes TopK with k either 32 or 64
  • then reshape or slice output to desired shape

Contributor

It is correct that the TopK algorithm only supports powers of 2 for K. Any non-power-2 values need to be rounded UP to the nearest supported K value, and then you can truncate the output if needed. Rounding down won't work.
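
The round-up-then-truncate rule described here can be sketched in plain Python. This is only a reference model, not the ttnn implementation: `SUPPORTED_K`, `padded_k`, `reference_topk`, and `topk_any_k` are hypothetical names, and the supported set (32, 64) is an assumption taken from the suggestion earlier in this thread.

```python
# Reference model of "round K up, run top-k, truncate" (all helper names are hypothetical).
SUPPORTED_K = (32, 64)  # assumption: the kernel-side K values discussed in this thread


def padded_k(k):
    """Return the smallest supported K that is >= k (round UP, never down)."""
    for candidate in SUPPORTED_K:
        if candidate >= k:
            return candidate
    raise ValueError(f"k={k} exceeds the largest supported K ({SUPPORTED_K[-1]})")


def reference_topk(row, k):
    """Stand-in for the device kernel: k largest values and their indices, descending."""
    order = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
    return [row[i] for i in order], order


def topk_any_k(row, k):
    """Run top-k with the padded K, then slice the result down to the requested k."""
    values, indices = reference_topk(row, padded_k(k))
    return values[:k], indices[:k]


row = [(i * 37) % 101 for i in range(128)]  # deterministic sample data
values, indices = topk_any_k(row, 50)
```

Rounding down to 32 would return only 32 candidates, so the 33rd through 50th largest values could not be recovered from the output; that is why the padded K must be at least the requested k.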

Contributor

@asandhupatlaTT I see only 32 and 64 in tests/ttnn/unit_tests/operations/test_topk.py
What happens when you try 2, 4, 8, and 16?

Contributor Author

@bbradelTT Since my kernel is similar to tt-buda, it should work, but I'll add a few test cases in the next patch.

@@ -34,9 +34,9 @@ void kernel_main() {
 uint32_t final_indices_cb_addr = get_write_ptr(final_indices_cb_index);

 uint64_t noc_final_addr_values =
-get_noc_addr(noc_final_x, noc_final_y, final_values_cb_addr) + start_wt * tile_bytes_values;
+get_noc_addr(noc_final_x, noc_final_y, final_values_cb_addr) + start_wt * tile_bytes_values * Kt;
Contributor

It would be good to clean up how Kt is defined as well in the program factory.

Contributor Author

Sorry, I don't understand this comment.

Contributor

In ttnn/cpp/ttnn/operations/reduction/topk/device/topk_program_factory.hpp Kt is defined by

uint32_t Kt = k % TILE_WIDTH == 0 ? k / TILE_WIDTH : k / TILE_WIDTH + 1;

That indicates that k could be something other than a multiple of TILE_WIDTH (e.g. 1, 2, 3, etc.). If that is not the case, it would be good to at least put a comment, and possibly change the code to just

uint32_t Kt = k / TILE_WIDTH;
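
To illustrate how the two definitions differ, here is a small Python model. `TILE_WIDTH = 32` is an assumption about the tile size, and `kt_ceil` and `kt_floor` are hypothetical names used only for this sketch.

```python
TILE_WIDTH = 32  # assumption: tiles are 32 elements wide


def kt_ceil(k):
    """Mirrors the current expression: round up to a whole number of tiles."""
    return k // TILE_WIDTH if k % TILE_WIDTH == 0 else k // TILE_WIDTH + 1


def kt_floor(k):
    """Mirrors the proposed simplification: valid only when k is a multiple of TILE_WIDTH."""
    return k // TILE_WIDTH


# Compare both definitions for multiples and non-multiples of TILE_WIDTH.
for k in (32, 64, 50, 16):
    print(k, kt_ceil(k), kt_floor(k))
```

The two agree whenever k is a multiple of TILE_WIDTH and differ otherwise, so the simpler form is safe only if k is guaranteed to be 32 or 64 by the time it reaches the program factory.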

cb_wait_front(input_transposed_cb_index, Wt);
cb_wait_front(index_transposed_cb_index, Wt);

while (idx < num_k_sequences) {
Contributor

This should be a for loop.

Contributor Author

I was sticking with the tt-buda code. OK, I'll convert the while loop to a for loop.


cb_push_back(input_transposed_cb_index, Wt);
cb_push_back(index_transposed_cb_index, Wt);
// print_all_tiles(input_transposed_cb_index, 0);
Contributor

Commented-out code should be deleted unless there is an explanation for keeping it.

Contributor Author

That's a good place to print in case people want to debug this in the future.

I see people adding such commented-out lines in other kernels, for example:

// DPRINT << cb_id_intermed0 << " "<< cb_id_intermed1 << " " <<intermed0_addr << " " << intermed1_addr <<" " <<

But I can remove it in the next patch.

release_dst();
direction = !direction;

while (idx < num_k_sequences) {
Contributor

Please use for loops where possible. while loops should only be used when for loops would be too awkward.

Contributor Author

I was sticking with the tt-buda code. OK, I'll convert the while loop to a for loop.

ttnn/cpp/ttnn/operations/reduction/topk/topk.cpp (outdated, resolved)

int end_phase = (K <= 64) ? logk - 1 : 5;
Contributor

Why do you need the ternary operator and this variable? If K == 64, isn't logk - 1 equal to 5?

Contributor Author

I was sticking with the tt-buda code.
For K > 64, it's capped at 5.

Contributor

Will this code be able to deal with K larger than 64? E.g. 128 or 256?

Contributor Author

That's the goal, but Radomir said the implementation is not there yet, so 64 is the max value we can support as of now.
@rdjogoTT let me know if I missed anything.

Contributor

K=64 was the max value we have tested in the past; however, the LLKs should be able to support K>64 as well. It should only require compute kernel-level changes, but it will be complex.

Contributor

@rdjogoTT If K=128 was supported, would this end phase still be 5?

Contributor

Yes. Since we load 2 tiles into Dest at a time, the max subsequence length we can sort is 64. K=128 would then require some additional steps.
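
A small Python model of the `end_phase` expression under discussion may help. It assumes that K is a power of 2 and that `logk` means log2(K); neither is confirmed by the code shown in this thread, and the function name is hypothetical.

```python
def end_phase(k):
    """Model of `(K <= 64) ? logk - 1 : 5`, assuming K is a power of 2 and logk = log2(K)."""
    if k <= 0 or k & (k - 1):
        raise ValueError("K must be a positive power of 2")
    logk = k.bit_length() - 1  # exact log2 for powers of 2
    return logk - 1 if k <= 64 else 5


# Print the end phase for each power of 2 from 2 to 256.
for k in (2, 4, 8, 16, 32, 64, 128, 256):
    print(k, end_phase(k))
```

For K = 64 both branches give 5, which is the point raised in the review; the cap only matters for K > 64, where the two-tile limit keeps the sortable subsequence length at 64.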

Contributor

@bbradelTT bbradelTT Jan 21, 2025

Ok, thanks. In that case, we can leave this as is. @asandhupatlaTT please add a comment about only supporting up to 64.

Contributor

@github-actions github-actions bot left a comment

⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

ttnn/cpp/ttnn/operations/reduction/topk/topk.cpp (outdated, resolved)
ttnn/cpp/ttnn/operations/reduction/topk/topk.cpp (outdated, resolved)
Contributor

@github-actions github-actions bot left a comment

⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
3 participants