support n-d tensor
Signed-off-by: Amruth Sandhupatla <asandhupatla@tenstorrent.com>
asandhupatlaTT committed Jan 22, 2025
1 parent 1f3d380 commit 9d4eef9
Showing 2 changed files with 163 additions and 27 deletions.
102 changes: 102 additions & 0 deletions tests/ttnn/unit_tests/operations/test_reduction.py
@@ -202,3 +202,105 @@ def test_mean_2d_tensor_dims(device, h, w, dim, keepdim):

    output_tensor = ttnn.to_torch(output_tensor)
    assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


@pytest.mark.parametrize("dim1", [1])
@pytest.mark.parametrize("dim2", [1])
@pytest.mark.parametrize("dim3", [8])
@pytest.mark.parametrize("dim4", [256])
@pytest.mark.parametrize("dim5", [64])
# @pytest.mark.parametrize("dim", [0, 1, 2, 3, 4]) transpose cannot handle N-D tensor for all dims
@pytest.mark.parametrize("dim", [3, 4])
@pytest.mark.parametrize("k", [17, 32, 64])
@pytest.mark.parametrize("largest", [True])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
def test_5d_topk(device, dim1, dim2, dim3, dim4, dim5, dim, k, largest, dtype):
    torch.manual_seed(2005)
    shape = [dim1, dim2, dim3, dim4, dim5]
    torch_dtype = torch.bfloat16

    input = torch.randn(shape, dtype=torch_dtype)
    pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=dim, largest=largest, sorted=True)

    ttnn_input = ttnn.from_torch(input, dtype, layout=ttnn.Layout.TILE, device=device)
    ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=dim, largest=largest, sorted=True)

    desired_shape = [dim1, dim2, dim3, dim4, dim5]
    desired_shape[dim] = k

    assert list(ttnn_topk_values.shape) == desired_shape
    assert list(ttnn_topk_indices.shape) == desired_shape

    ttnn_torch_values = ttnn.to_torch(ttnn_topk_values)
    ttnn_torch_indices = ttnn.to_torch(ttnn_topk_indices).to(torch.int64)

    if dtype == ttnn.bfloat8_b:
        pcc_values = 0.99
    else:
        pcc_values = 1.0

    # pcc is not a good measure for the raw indices:
    # if index 49 and index 8 are tied, the order of the returned indices can differ,
    # but the values those indices point to should be the same.
    # if indices 7 and 8 are tied but swapped, the pcc looks better than if indices 49 and 8
    # are tied but swapped, and rounding may also cause more ties than expected.
    # the larger the tensor, the tighter the distribution of the top-k elements, so the pcc
    # gets worse as stability/rounding causes more ties.
    # instead, use cosine similarity on the values gathered at the returned indices, which
    # shows that the top elements are all about the same.
    ttnn_torch_gather_from_indices = torch.gather(input, dim, ttnn_torch_indices.to(torch.int64))
    cosine = torch.nn.CosineSimilarity(dim=dim)
    ttnn_torch_cosine = torch.mean(cosine(pyt_topk_values, ttnn_torch_gather_from_indices))

    assert ttnn_torch_cosine > 0.99, "Cosine similarity between topk values and gather from indices is less than 0.99"
    assert_with_pcc(pyt_topk_values, ttnn_torch_values, pcc_values)


@pytest.mark.parametrize("dim1", [1])
@pytest.mark.parametrize("dim2", [1])
@pytest.mark.parametrize("dim3", [8])
@pytest.mark.parametrize("dim4", [1])
@pytest.mark.parametrize("dim5", [128])
@pytest.mark.parametrize("dim6", [64])
# @pytest.mark.parametrize("dim", [0, 1, 2, 3, 4, 5]) transpose cannot handle N-D tensor for all dims
@pytest.mark.parametrize("dim", [4, 5])
@pytest.mark.parametrize("k", [50, 64])
@pytest.mark.parametrize("largest", [True])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
def test_6d_topk(device, dim1, dim2, dim3, dim4, dim5, dim6, dim, k, largest, dtype):
    torch.manual_seed(2005)
    shape = [dim1, dim2, dim3, dim4, dim5, dim6]
    torch_dtype = torch.bfloat16

    input = torch.randn(shape, dtype=torch_dtype)
    pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=dim, largest=largest, sorted=True)

    ttnn_input = ttnn.from_torch(input, dtype, layout=ttnn.Layout.TILE, device=device)
    ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=dim, largest=largest, sorted=True)

    desired_shape = [dim1, dim2, dim3, dim4, dim5, dim6]
    desired_shape[dim] = k

    assert list(ttnn_topk_values.shape) == desired_shape
    assert list(ttnn_topk_indices.shape) == desired_shape

    ttnn_torch_values = ttnn.to_torch(ttnn_topk_values)
    ttnn_torch_indices = ttnn.to_torch(ttnn_topk_indices).to(torch.int64)

    if dtype == ttnn.bfloat8_b:
        pcc_values = 0.99
    else:
        pcc_values = 1.0

    # pcc is not a good measure for the raw indices:
    # if index 49 and index 8 are tied, the order of the returned indices can differ,
    # but the values those indices point to should be the same.
    # if indices 7 and 8 are tied but swapped, the pcc looks better than if indices 49 and 8
    # are tied but swapped, and rounding may also cause more ties than expected.
    # the larger the tensor, the tighter the distribution of the top-k elements, so the pcc
    # gets worse as stability/rounding causes more ties.
    # instead, use cosine similarity on the values gathered at the returned indices, which
    # shows that the top elements are all about the same.
    ttnn_torch_gather_from_indices = torch.gather(input, dim, ttnn_torch_indices.to(torch.int64))
    cosine = torch.nn.CosineSimilarity(dim=dim)
    ttnn_torch_cosine = torch.mean(cosine(pyt_topk_values, ttnn_torch_gather_from_indices))

    assert ttnn_torch_cosine > 0.99, "Cosine similarity between topk values and gather from indices is less than 0.99"
    assert_with_pcc(pyt_topk_values, ttnn_torch_values, pcc_values)
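
The check used in both tests above can be illustrated in isolation: when two candidates are tied, the device is free to return either index, so the tests compare the values gathered at the returned indices rather than the raw indices. Below is a minimal, torch-only sketch of that idea; the helper name, the example tensor, and the 0.99 threshold are illustrative and not part of the commit.

import torch

def topk_values_match_via_gather(x, k, dim, device_indices, reference_values, threshold=0.99):
    # Gather the input values at the indices the device returned. If two entries are
    # tied, a different-but-valid index yields the same gathered value, so this check
    # is insensitive to tie-breaking order, unlike a PCC on the raw indices.
    gathered = torch.gather(x, dim, device_indices.to(torch.int64))
    cosine = torch.nn.CosineSimilarity(dim=dim)(reference_values, gathered)
    return torch.mean(cosine).item() > threshold

# Tied values: torch.topk may return indices [0, 2], but [2, 0] is an equally valid answer.
x = torch.tensor([[3.0, 1.0, 3.0, 0.5]])
ref_values, ref_indices = torch.topk(x, k=2, dim=1, largest=True, sorted=True)
swapped_indices = torch.tensor([[2, 0]])
assert topk_values_match_via_gather(x, 2, 1, swapped_indices, ref_values)
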
88 changes: 61 additions & 27 deletions ttnn/cpp/ttnn/operations/reduction/topk/topk.cpp
@@ -16,6 +16,61 @@ uint16_t get_nearest_power_of_2(uint16_t k) {
    return nearest_power_of_2;
}

inline Tensor perform_transpose(
    const Tensor& input_tensor, const bool is_dim_last_idx, const int8_t dim1 = -1, const int8_t dim2 = -1) {
    return is_dim_last_idx ? input_tensor : ttnn::transpose(input_tensor, dim1, dim2, input_tensor.memory_config());
}

inline Tensor transform_to_4d_tensor(const Tensor& input_tensor, const bool is_rank_le_4d) {
    return is_rank_le_4d ? ttnn::unsqueeze_to_4D(input_tensor) : data_movement::squeeze_from_ND_to_4D(input_tensor);
}

// one stop for all transformations needed after executing top-k
// do we need a separate function for each case? revisit this later
std::vector<Tensor> post_topk_transform_tensor(
    const Tensor& input_tensor,
    std::vector<Tensor>& result,
    const int8_t dim,
    const bool is_dim_last_idx,
    const uint16_t k,
    const uint16_t adjusted_k,
    const MemoryConfig& input_memory_config) {
    TT_ASSERT(result[0].get_shape().rank() == 4, "Output shape rank must be 4");
    TT_ASSERT(result[1].get_shape().rank() == 4, "Output shape rank must be 4");

    auto input_shape = input_tensor.get_logical_shape();
    const auto orig_rank = input_shape.rank();

    // case 1: k is not a power of 2
    if (adjusted_k != k) {
        auto output_shape = result[0].get_shape();
        ttnn::SmallVector<uint32_t> step = {1, 1, 1, 1};
        ttnn::SmallVector<uint32_t> start_index = {0, 0, 0, 0};
        ttnn::SmallVector<uint32_t> end_index = {output_shape[0], output_shape[1], output_shape[2], k};
        result[0] = ttnn::slice(result[0], start_index, end_index, step, input_memory_config);
        result[1] = ttnn::slice(result[1], start_index, end_index, step, input_memory_config);
    }

    // case 2: rank is not 4
    if (orig_rank < 4) {
        result[0] = ttnn::squeeze_from_4D(result[0], orig_rank);
        result[1] = ttnn::squeeze_from_4D(result[1], orig_rank);
    } else if (orig_rank > 4) {
        ttnn::SmallVector<uint32_t> result_shape(input_shape.cbegin(), input_shape.cend());
        result_shape[result_shape.size() - 1] = k;
        result[0] = ttnn::reshape(result[0], ttnn::SimpleShape{result_shape});
        result[1] = ttnn::reshape(result[1], ttnn::SimpleShape{result_shape});
    }

    // case 3: dim is not the last index
    if (!is_dim_last_idx) {
        result[0] = ttnn::transpose(result[0], dim, -1, input_tensor.memory_config());
        result[1] = ttnn::transpose(result[1], dim, -1, input_tensor.memory_config());
    }

    return result;
}

std::vector<Tensor> ExecuteTopK::invoke(
    uint8_t queue_id,
    const Tensor& input_tensor,
@@ -33,16 +88,10 @@ std::vector<Tensor> ExecuteTopK::invoke(

    // K may not be power of 2
    uint16_t adjusted_k = get_nearest_power_of_2(k);
    // TODO : we may also have to address N-D tensor inputs in future
    // tensor transformed_input_tensor = is_rank_le_4d ? ttnn::unsqueeze_to_4D(input_tensor_arg) :
    // data_movement::squeeze_from_ND_to_4D(input_tensor);

    // support any dim value
    auto transform_tensor = [&](const Tensor& input_tensor, const int8_t dim1, const int8_t dim2 = -1) {
        return ttnn::transpose(input_tensor, dim1, dim2, input_memory_config);
    };

    Tensor transformed_tensor = is_dim_last_idx ? input_tensor : transform_tensor(input_tensor, dim);
    // if dim is not the last dimension, transpose it to the last position
    Tensor transposed_tensor = perform_transpose(input_tensor, is_dim_last_idx, dim, -1);
    // if the input is not 4D, convert it to 4D
    Tensor transformed_tensor = transform_to_4d_tensor(transposed_tensor, is_rank_le_4d);

    auto output_tensor_vec = operation::run(
        TopK{adjusted_k, -1, largest, sorted, input_memory_config},
@@ -52,23 +101,8 @@ std::vector<Tensor> ExecuteTopK::invoke(
            : std::vector<std::optional<Tensor>>{},
        queue_id);

    if (adjusted_k != k) {
        auto output_shape = output_tensor_vec[0].get_shape();
        ttnn::SmallVector<uint32_t> step = {1, 1, 1, 1};
        ttnn::SmallVector<uint32_t> start_index = {0, 0, 0, 0};
        ttnn::SmallVector<uint32_t> end_index = {output_shape[0], output_shape[1], output_shape[2], k};
        output_tensor_vec[0] = ttnn::slice(output_tensor_vec[0], start_index, end_index, step, input_memory_config);
        output_tensor_vec[1] = ttnn::slice(output_tensor_vec[1], start_index, end_index, step, input_memory_config);
    }

    if (is_dim_last_idx) {
        return output_tensor_vec;
    }

    std::vector<Tensor> result_vec(2);
    result_vec[0] = transform_tensor(output_tensor_vec[0], -1, dim);
    result_vec[1] = transform_tensor(output_tensor_vec[1], -1, dim);
    return result_vec;
    return post_topk_transform_tensor(
        transposed_tensor, output_tensor_vec, dim, is_dim_last_idx, k, adjusted_k, input_memory_config);
}

auto ExecuteTopK::invoke(

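Taken together, the topk.cpp changes route an N-D input through three shape transformations: transpose the reduction dim to the last axis (perform_transpose), bring the rank to 4 (transform_to_4d_tensor), run the 4-D top-k kernel, and then undo everything in post_topk_transform_tensor. The torch-only sketch below mirrors that bookkeeping for rank > 4 so the intent is easy to follow; the function name is illustrative, torch.topk stands in for the device kernel, and this is an outline of the approach rather than the implementation.

import torch

def topk_nd_via_last_dim(x, k, dim, largest=True):
    # Mirror of perform_transpose: move the reduction dim to the last axis,
    # since the underlying kernel only reduces over the last dimension.
    last = x.dim() - 1
    if dim != last and dim != -1:
        x = x.transpose(dim, last)

    # Mirror of squeeze_from_ND_to_4D: collapse leading dims so the tensor is rank 4.
    # (The real code also rounds k up to a power of two and slices the result back to k
    # afterwards; torch.topk has no such restriction, so that step is omitted here.)
    working_shape = list(x.shape)
    if x.dim() > 4:
        x = x.reshape(-1, *working_shape[-3:])

    values, indices = torch.topk(x, k, dim=-1, largest=largest, sorted=True)

    # Mirror of post_topk_transform_tensor: restore the original rank, then move
    # the reduction dim back to where the caller asked for it.
    out_shape = working_shape[:-1] + [k]
    values = values.reshape(out_shape)
    indices = indices.reshape(out_shape)
    if dim != last and dim != -1:
        values = values.transpose(dim, last)
        indices = indices.transpose(dim, last)
    return values, indices

# Shapes match the 6-D test above: top-50 along dim 4 of a [1, 1, 8, 1, 128, 64] tensor.
x = torch.randn(1, 1, 8, 1, 128, 64)
values, indices = topk_nd_via_last_dim(x, k=50, dim=4)
assert list(values.shape) == [1, 1, 8, 1, 50, 64]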