Address issues of top-k op #16670

Draft
wants to merge 16 commits into base: main
102 changes: 102 additions & 0 deletions tests/ttnn/unit_tests/operations/test_reduction.py
@@ -202,3 +202,105 @@ def test_mean_2d_tensor_dims(device, h, w, dim, keepdim):

output_tensor = ttnn.to_torch(output_tensor)
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


@pytest.mark.parametrize("dim1", [1])
@pytest.mark.parametrize("dim2", [1])
@pytest.mark.parametrize("dim3", [8])
@pytest.mark.parametrize("dim4", [256])
@pytest.mark.parametrize("dim5", [64])
# @pytest.mark.parametrize("dim", [0, 1, 2, 3, 4]) transpose cannot handle N-D tensor for all dims
@pytest.mark.parametrize("dim", [3, 4])
@pytest.mark.parametrize("k", [17, 32, 64])
@pytest.mark.parametrize("largest", [True])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
def test_5d_topk(device, dim1, dim2, dim3, dim4, dim5, dim, k, largest, dtype):
torch.manual_seed(2005)
shape = [dim1, dim2, dim3, dim4, dim5]
torch_dtype = torch.bfloat16

input = torch.randn(shape, dtype=torch_dtype)
pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=dim, largest=largest, sorted=True)

ttnn_input = ttnn.from_torch(input, dtype, layout=ttnn.Layout.TILE, device=device)
ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=dim, largest=largest, sorted=True)

desired_shape = [dim1, dim2, dim3, dim4, dim5]
desired_shape[dim] = k

assert list(ttnn_topk_values.shape) == desired_shape
assert list(ttnn_topk_indices.shape) == desired_shape

ttnn_torch_values = ttnn.to_torch(ttnn_topk_values)
ttnn_torch_indices = ttnn.to_torch(ttnn_topk_indices).to(torch.int64)

if dtype == ttnn.bfloat8_b:
pcc_values = 0.99
else:
pcc_values = 1.0

# pcc is not a good measure for the raw indices
# if index 49 and index 8 are tied, the order of the indices can be different
# but the values associated with the indices should be the same
# if index 7 and 8 are tied, but swapped, the pcc will be better than if index 49 and 8 are tied but swapped
# rounding may also cause more ties than expected
# the bigger we get, the tighter the distribution of the top K elements, so the pcc will be worse as stability/rounding will cause more ties
# use cosine similarity on the gathered indices as this will show the top elements are all about the same
ttnn_torch_gather_from_indices = torch.gather(input, dim, ttnn_torch_indices.to(torch.int64))
cosine = torch.nn.CosineSimilarity(dim=dim)
ttnn_torch_cosine = torch.mean(cosine(pyt_topk_values, ttnn_torch_gather_from_indices))

assert ttnn_torch_cosine > 0.99, "Cosine similarity between topk values and gather from indices is less than 0.99"
assert_with_pcc(pyt_topk_values, ttnn_torch_values, pcc_values)


@pytest.mark.parametrize("dim1", [1])
@pytest.mark.parametrize("dim2", [1])
@pytest.mark.parametrize("dim3", [8])
@pytest.mark.parametrize("dim4", [1])
@pytest.mark.parametrize("dim5", [128])
@pytest.mark.parametrize("dim6", [64])
# @pytest.mark.parametrize("dim", [0, 1, 2, 3, 4, 5]) transpose cannot handle N-D tensor for all dims
@pytest.mark.parametrize("dim", [4, 5])
@pytest.mark.parametrize("k", [50, 64])
@pytest.mark.parametrize("largest", [True])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
def test_6d_topk(device, dim1, dim2, dim3, dim4, dim5, dim6, dim, k, largest, dtype):
torch.manual_seed(2005)
shape = [dim1, dim2, dim3, dim4, dim5, dim6]
torch_dtype = torch.bfloat16

input = torch.randn(shape, dtype=torch_dtype)
pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=dim, largest=largest, sorted=True)

ttnn_input = ttnn.from_torch(input, dtype, layout=ttnn.Layout.TILE, device=device)
ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=dim, largest=largest, sorted=True)

desired_shape = [dim1, dim2, dim3, dim4, dim5, dim6]
desired_shape[dim] = k

assert list(ttnn_topk_values.shape) == desired_shape
assert list(ttnn_topk_indices.shape) == desired_shape

ttnn_torch_values = ttnn.to_torch(ttnn_topk_values)
ttnn_torch_indices = ttnn.to_torch(ttnn_topk_indices).to(torch.int64)

if dtype == ttnn.bfloat8_b:
pcc_values = 0.99
else:
pcc_values = 1.0

# pcc is not a good measure for the raw indices
# if index 49 and index 8 are tied, the order of the indices can be different
# but the values associated with the indices should be the same
# if index 7 and 8 are tied, but swapped, the pcc will be better than if index 49 and 8 are tied but swapped
# rounding may also cause more ties than expected
# the bigger we get, the tighter the distribution of the top K elements, so the pcc will be worse as stability/rounding will cause more ties
# use cosine similarity on the gathered indices as this will show the top elements are all about the same
ttnn_torch_gather_from_indices = torch.gather(input, dim, ttnn_torch_indices.to(torch.int64))
cosine = torch.nn.CosineSimilarity(dim=dim)
ttnn_torch_cosine = torch.mean(cosine(pyt_topk_values, ttnn_torch_gather_from_indices))

assert ttnn_torch_cosine > 0.99, "Cosine similarity between topk values and gather from indices is less than 0.99"
assert_with_pcc(pyt_topk_values, ttnn_torch_values, pcc_values)
56 changes: 40 additions & 16 deletions tests/ttnn/unit_tests/operations/test_topk.py
@@ -12,19 +12,22 @@
from models.utility_functions import skip_for_grayskull


def run_topk_test(N, C, H, W, k, dtype, device):
def run_topk_test(N, C, H, W, k, dtype, dim, sorted, largest, device):
torch.manual_seed(2005)
shape = [N, C, H, W]
torch_dtype = torch.bfloat16

input = torch.randn(shape, dtype=torch_dtype)
pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=-1, largest=True, sorted=True)
pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=dim, largest=largest, sorted=sorted)

ttnn_input = ttnn.from_torch(input, dtype, layout=ttnn.Layout.TILE, device=device)
ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=-1, largest=True, sorted=True)
ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=dim, largest=largest, sorted=sorted)

assert list(ttnn_topk_values.shape.with_tile_padding()) == [N, C, H, k]
assert list(ttnn_topk_indices.shape.with_tile_padding()) == [N, C, H, k]
desired_shape = [N, C, H, W]
desired_shape[dim] = k

assert list(ttnn_topk_values.shape) == desired_shape
assert list(ttnn_topk_indices.shape) == desired_shape

ttnn_torch_values = ttnn.to_torch(ttnn_topk_values)
ttnn_torch_indices = ttnn.to_torch(ttnn_topk_indices).to(torch.int64)
@@ -39,10 +42,10 @@ def run_topk_test(N, C, H, W, k, dtype, device):
# but the values associated with the indices should be the same
# if index 7 and 8 are tied, but swapped, the pcc will be better than if index 49 and 8 are tied but swapped
# rounding may also cause more ties than expected
# the bigger we get, the tighter the distribution of the top 32 elements, so the pcc will be worse as stability/rounding will cause more ties
# the bigger we get, the tighter the distribution of the top K elements, so the pcc will be worse as stability/rounding will cause more ties
# use cosine similarity on the gathered indices as this will show the top elements are all about the same
ttnn_torch_gather_from_indices = torch.gather(input, -1, ttnn_torch_indices.to(torch.int64))
cosine = torch.nn.CosineSimilarity(dim=-1)
ttnn_torch_gather_from_indices = torch.gather(input, dim, ttnn_torch_indices.to(torch.int64))
cosine = torch.nn.CosineSimilarity(dim=dim)
ttnn_torch_cosine = torch.mean(cosine(pyt_topk_values, ttnn_torch_gather_from_indices))

assert ttnn_torch_cosine > 0.99, "Cosine similarity between topk values and gather from indices is less than 0.99"
@@ -64,14 +67,35 @@ def run_topk_test(N, C, H, W, k, dtype, device):
],
)
@pytest.mark.parametrize(
"N, C, H, W, k,",
"N, C, H, W, dim, k",
(
(1, 1, 32, 64, 32),
(1, 1, 32, 8192, 32),
(1, 1, 2048, 64, 32),
(1, 1, 32, 32768, 32),
(1, 1, 8192, 64, 32),
(1, 1, 64, 64, 2, 32),
Contributor
The issue mentions a k of 50. That should be tested as well.

Contributor Author
Yes, I tried k=50, but if I'm not mistaken, tt-buda (and thereby my code) only supports powers of 2 for k.
I'll see what needs to be done to support non-power-of-2 values.

One idea is: convert k to the nearest power of 2 --> run the LLK/compute kernel --> then reshape or slice the output to the desired shape

Contributor
It would be easiest to use the separation of ExecuteTopK and TopK and do what you suggested, except convert k to either 32 or 64.

ExecuteTopK:

  • invokes TopK with k either 32 or 64
  • then reshapes or slices the output to the desired shape

Contributor
It is correct that the TopK algorithm only supports powers of 2 for k. Any non-power-of-2 value needs to be rounded UP to the nearest supported k value, and then you can truncate the output if needed. Rounding down won't work.
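A minimal sketch of the round-up-then-truncate idea discussed above, written at the Python/ttnn level rather than inside the C++ ExecuteTopK path. The wrapper name topk_with_padding, the assumption that only widths 32 and 64 are supported, and the host-side truncation via torch are illustrative assumptions, not the op's actual implementation.

```python
import torch
import ttnn


def topk_with_padding(ttnn_input, k, dim=-1, largest=True, sorted=True):
    # Hypothetical helper: round k up to the nearest width the device
    # kernel is assumed to support (32 or 64). Rounding down would drop
    # genuine top-k results, so only rounding up is valid.
    assert 0 < k <= 64, "sketch assumes the kernel's max supported k is 64"
    padded_k = 32 if k <= 32 else 64
    values, indices = ttnn.topk(ttnn_input, padded_k, dim=dim, largest=largest, sorted=sorted)
    if padded_k == k:
        return values, indices
    # Truncate the padded output back to the requested k. The slice is
    # done host-side with torch.narrow to keep the sketch simple.
    values = ttnn.to_torch(values).narrow(dim, 0, k)
    indices = ttnn.to_torch(indices).narrow(dim, 0, k)
    return values, indices
```

Implemented on-device, the truncation would instead be a slice of the padded result before it is returned, so callers never see the padded width.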

Contributor
@asandhupatlaTT I see only 32 and 64 in tests/ttnn/unit_tests/operations/test_topk.py.
What happens when you try 2, 4, 8, and 16?

Contributor Author
@bbradelTT Since my kernel is similar to tt-buda's, it should work, but I'll add a few test cases in the next patch.

(1, 1, 64, 64, 2, 64),
(1, 1, 32, 8192, 3, 50),
(1, 2048, 1, 64, 1, 32),
(1, 1, 32, 32768, 3, 17),
(128, 1, 1, 64, 0, 64),
),
)
def test_topk(N, C, H, W, k, dtype, device):
run_topk_test(N, C, H, W, k, dtype, device)
@pytest.mark.parametrize(
"sorted",
[
True,
# False, Please refer to https://github.com/tenstorrent/tt-metal/issues/13235#issuecomment-2601432673
],
)
@pytest.mark.parametrize(
"largest",
[
True,
# False, Waiting for Ata's patch to be merged
],
)
def test_topk(N, C, H, W, dim, k, dtype, sorted, largest, device):
if dim == 0 or dim == 1:
# As of now, when we try to get top-k for dim = 0 or 1, we get the following error from transpose_op.cpp's validate():
# input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::FLOAT32
# this is because transpose.cpp always typecasts bf8 to bf16,
# and when dim = 0 or 1, transpose converts it into TransposeOpDim::HC, and this dim doesn't support bf16 or fp32
pytest.skip()
run_topk_test(N, C, H, W, k, dtype, dim, sorted, largest, device)
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
@@ -361,6 +361,7 @@ set(TTNN_OP_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/prod/device/prod_op_all.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/prod/prod.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/topk/topk.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/halo/halo.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/sliding_window.cpp