Commit d0be909: support all dims

Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
asandhupatlaTT committed Jan 18, 2025
1 parent 1f1ef8d commit d0be909
Showing 5 changed files with 79 additions and 35 deletions.
36 changes: 16 additions & 20 deletions tests/ttnn/unit_tests/operations/test_topk.py
@@ -12,19 +12,22 @@
 from models.utility_functions import skip_for_grayskull


-def run_topk_test(N, C, H, W, k, dtype, device):
+def run_topk_test(N, C, H, W, k, dtype, dim, device):
     torch.manual_seed(2005)
     shape = [N, C, H, W]
     torch_dtype = torch.bfloat16

     input = torch.randn(shape, dtype=torch_dtype)
-    pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=-1, largest=True, sorted=True)
+    pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=dim, largest=True, sorted=True)

     ttnn_input = ttnn.from_torch(input, dtype, layout=ttnn.Layout.TILE, device=device)
-    ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=-1, largest=True, sorted=True)
+    ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=dim, largest=True, sorted=True)

-    assert list(ttnn_topk_values.shape.with_tile_padding()) == [N, C, H, k]
-    assert list(ttnn_topk_indices.shape.with_tile_padding()) == [N, C, H, k]
+    desired_shape = [N, C, H, W]
+    desired_shape[dim] = k
+
+    assert list(ttnn_topk_values.shape.with_tile_padding()) == desired_shape
+    assert list(ttnn_topk_indices.shape.with_tile_padding()) == desired_shape

     ttnn_torch_values = ttnn.to_torch(ttnn_topk_values)
     ttnn_torch_indices = ttnn.to_torch(ttnn_topk_indices).to(torch.int64)
@@ -64,21 +64,14 @@ def run_topk_test(N, C, H, W, k, dtype, device):
     ],
 )
 @pytest.mark.parametrize(
-    "N, C, H, W,",
+    "N, C, H, W, dim, k",
     (
-        (1, 1, 32, 64),
-        (1, 1, 32, 8192),
-        (1, 1, 2048, 64),
-        (1, 1, 32, 32768),
-        (1, 1, 8192, 64),
+        (1, 1, 64, 64, 2, 32),
+        (1, 1, 32, 8192, 3, 64),
+        (1, 2048, 1, 64, 1, 32),
+        (1, 1, 32, 32768, 3, 64),
+        (128, 1, 1, 64, 0, 64),
     ),
 )
-@pytest.mark.parametrize(
-    "k",
-    [
-        32,
-        64,
-    ],
-)
-def test_topk(N, C, H, W, k, dtype, device):
-    run_topk_test(N, C, H, W, k, dtype, device)
+def test_topk(N, C, H, W, dim, k, dtype, device):
+    run_topk_test(N, C, H, W, k, dtype, dim, device)
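
A minimal sketch (plain torch, no ttnn device required) of the shape rule the updated assertions encode: topk along dim keeps every other extent and replaces only the reduced extent with k, which is exactly what desired_shape[dim] = k expresses. The numbers are taken from one case of the new parametrization.

import torch

# One case from the new parametrization: (N, C, H, W, dim, k)
N, C, H, W, dim, k = 1, 1, 64, 64, 2, 32
values, indices = torch.topk(torch.randn(N, C, H, W), k, dim=dim)
expected = [N, C, H, W]
expected[dim] = k  # only the reduced dim shrinks to k
assert list(values.shape) == expected and list(indices.shape) == expected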
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
@@ -357,6 +357,7 @@ set(TTNN_OP_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/prod/device/prod_op_all.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/prod/prod.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/topk/topk.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/halo/halo.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/sliding_window.cpp
1 change: 0 additions & 1 deletion ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp
@@ -43,7 +43,6 @@ void TopK::validate_with_output_tensors(
     const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
     auto input_shape = input_tensors.at(0).get_legacy_shape();
     TT_FATAL(input_shape.rank() == 4, "Input shape must be 4D, got {}", input_shape.rank());
-    TT_FATAL(this->dim == -1 || this->dim == 3, "Only the last dim is supported right now, got {}", this->dim);

     TT_FATAL(
         input_shape[-1] >= 64,
58 changes: 58 additions & 0 deletions ttnn/cpp/ttnn/operations/reduction/topk/topk.cpp
@@ -0,0 +1,58 @@
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "topk.hpp"
+#include "ttnn/operations/data_movement/transpose/transpose.hpp"
+
+namespace ttnn::operations::reduction {
+
+std::vector<Tensor> ExecuteTopK::invoke(
+    uint8_t queue_id,
+    const Tensor& input_tensor,
+    const uint16_t k,
+    const int8_t dim,
+    const bool largest,
+    const bool sorted,
+    const std::optional<MemoryConfig>& memory_config,
+    std::optional<std::tuple<Tensor, Tensor>> optional_output_tensors) {
+    const bool is_dim_last_idx = (dim == -1 || dim == 3);
+    auto input_memory_config = memory_config.value_or(input_tensor.memory_config());
+
+    // TODO : we may also have to address N-D tensor inputs in future
+    auto transform_tensor = [&](const Tensor& input_tensor, const int8_t dim1, const int8_t dim2 = -1) {
+        return ttnn::transpose(input_tensor, dim1, dim2, input_memory_config);
+    };
+
+    Tensor transformed_tensor = is_dim_last_idx ? input_tensor : transform_tensor(input_tensor, dim);
+
+    auto output_tensor_vec = operation::run(
+        TopK{k, -1, largest, sorted, input_memory_config},
+        {transformed_tensor},
+        {},
+        optional_output_tensors.has_value() ? tuple_to_vector_optional(optional_output_tensors.value())
+                                            : std::vector<std::optional<Tensor>>{},
+        queue_id);
+
+    if (is_dim_last_idx) {
+        return output_tensor_vec;
+    }
+
+    std::vector<Tensor> result_vec(2);
+    result_vec[0] = transform_tensor(output_tensor_vec[0], -1, dim);
+    result_vec[1] = transform_tensor(output_tensor_vec[1], -1, dim);
+    return result_vec;
+}
+
+auto ExecuteTopK::invoke(
+    const Tensor& input_tensor,
+    const uint16_t k,
+    const int8_t dim,
+    const bool largest,
+    const bool sorted,
+    const std::optional<MemoryConfig>& memory_config,
+    std::optional<std::tuple<Tensor, Tensor>> optional_output_tensors) {
+    return invoke(DefaultQueueId, input_tensor, k, dim, largest, sorted, memory_config, optional_output_tensors);
+}
+
+}  // namespace ttnn::operations::reduction
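
The new invoke wraps the existing last-dim kernel rather than adding a new one: for any other dim it transposes that dim to the end, runs topk there, and transposes both outputs back. A hedged sketch of why that composition matches a direct reduction, with torch standing in for the ttnn transpose and topk ops:

import torch

def topk_via_last_dim(x: torch.Tensor, k: int, dim: int):
    # Fast path: the underlying kernel already reduces over the last dim.
    if dim in (-1, x.dim() - 1):
        return torch.topk(x, k, dim=-1)
    # Move the target dim to the end, reduce there, then move it back,
    # mirroring transform_tensor(input, dim) / transform_tensor(output, -1, dim).
    xt = x.transpose(dim, -1)
    values, indices = torch.topk(xt, k, dim=-1)
    return values.transpose(dim, -1), indices.transpose(dim, -1)

x = torch.randn(1, 2048, 1, 64)  # shape from one of the new test cases
v1, i1 = topk_via_last_dim(x, k=32, dim=1)
v2, i2 = torch.topk(x, 32, dim=1)
assert torch.equal(v1, v2) and torch.equal(i1, i2)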
18 changes: 4 additions & 14 deletions ttnn/cpp/ttnn/operations/reduction/topk/topk.hpp
@@ -23,34 +23,24 @@ namespace ttnn {
 namespace operations::reduction {

 struct ExecuteTopK {
-    static inline std::vector<Tensor> invoke(
+    static std::vector<Tensor> invoke(
         uint8_t queue_id,
         const Tensor& input_tensor,
         const uint16_t k,
         const int8_t dim,
         const bool largest,
         const bool sorted,
         const std::optional<MemoryConfig>& memory_config,
-        std::optional<std::tuple<Tensor, Tensor>> optional_output_tensors = std::nullopt) {
-        return operation::run(
-            TopK{k, dim, largest, sorted, memory_config.value_or(input_tensor.memory_config())},
-            {input_tensor},
-            {},
-            optional_output_tensors.has_value() ? tuple_to_vector_optional(optional_output_tensors.value())
-                                                : std::vector<std::optional<Tensor>>{},
-            queue_id);
-    }
+        std::optional<std::tuple<Tensor, Tensor>> optional_output_tensors = std::nullopt);

-    static inline auto invoke(
+    static auto invoke(
         const Tensor& input_tensor,
         const uint16_t k,
         const int8_t dim,
         const bool largest,
         const bool sorted,
         const std::optional<MemoryConfig>& memory_config,
-        std::optional<std::tuple<Tensor, Tensor>> optional_output_tensors) {
-        return invoke(DefaultQueueId, input_tensor, k, dim, largest, sorted, memory_config, optional_output_tensors);
-    }
+        std::optional<std::tuple<Tensor, Tensor>> optional_output_tensors = std::nullopt);

     static inline std::vector<Tensor> create_async_output_tensors(
         const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs) {
