Skip to content

Commit

Permalink
#12662: add keepdim and max fixes to reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
bbradelTT committed Dec 19, 2024
1 parent 485a18d commit ebf7dbd
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 33 deletions.
30 changes: 30 additions & 0 deletions tests/ttnn/unit_tests/operations/test_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,33 @@ def test_max_global(device, batch_size, h, w):
output_tensor = output_tensor[0, 0, 0]

assert_with_pcc(torch_output_tensor, output_tensor)


@pytest.mark.parametrize(
    "input_shape_and_dim",
    [
        ((1, 2, 3, 4), -1),
        ((2, 32, 64, 64), -4),
        ((2, 22, 37, 41), -4),
        ((2, 32, 64, 64), -3),
        ((2, 22, 37, 41), -3),
        ((2, 32, 64), -3),
        ((2, 22, 37), -3),
    ],
)
@pytest.mark.parametrize("keepdim", [True, False])
def test_max_dim(device, input_shape_and_dim, keepdim):
    """Compare ttnn.max along one dimension against torch.max.

    Each case supplies an input shape together with the (negative) dimension
    to reduce over, and is run with keepdim both enabled and disabled.
    """
    shape, reduce_dim = input_shape_and_dim

    # Reference result: torch.max with a dim returns (values, indices); only
    # the values are compared.
    reference_input = torch_random(shape, -100, 100, dtype=torch.bfloat16)
    expected = torch.max(reference_input, dim=reduce_dim, keepdim=keepdim)[0]

    # Same data, moved to the device in tile layout.
    device_input = ttnn.from_torch(reference_input, layout=ttnn.TILE_LAYOUT, device=device)

    # Run the op under test, then bring the result back to a torch tensor.
    result = ttnn.max(device_input, dim=reduce_dim, keepdim=keepdim)
    result = ttnn.to_layout(result, ttnn.TILE_LAYOUT)
    result = ttnn.from_device(result)
    result = ttnn.to_torch(result)

    assert_with_pcc(expected, result)
110 changes: 77 additions & 33 deletions ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,62 @@
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
#include "ttnn/operations/data_movement/pad/pad.hpp"
#include "ttnn/operations/data_movement/slice/slice.hpp"
#include "ttnn/operations/data_movement/transpose/transpose.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/eltwise/binary/binary_composite.hpp"
#include "ttnn/operations/reduction/generic/device/reduce_op.hpp"
#include "ttnn/operations/core/core.hpp"

// Rewrites the tile padding of `input_tensor` so that it holds `pad_value`.
//
// Some tensors arrive already padded with 0s (e.g. those generated via
// from_torch), so the padding is always applied again here: the tensor is
// converted to row major, padded out to its tile-padded shape with `pad_value`,
// and then converted back to tile layout.
// The returned tensor is padded; once the reduce has run, the result needs to be
// sliced back down by the caller.
// NOTE(review): the earlier comment on this helper also described a transpose /
// un-transpose step forced by limitations of pad; no transpose is visible in this
// body — confirm whether that remark is still accurate.
// NOTE(review): the debug log is tagged "max", but this helper is also reached
// from the Min reduction path in this file — confirm the label is intended.
ttnn::Tensor pad_tensor_with_value(const ttnn::Tensor& input_tensor, float pad_value) {
    // Target shape: the input's shape including tile padding (indexes dims 0..3).
    ttnn::Shape tile_padded = input_tensor.get_shape().with_tile_padding();
    tt::tt_metal::Array4D padded_shape = {tile_padded[0], tile_padded[1], tile_padded[2], tile_padded[3]};

    // Padding is applied on a row-major copy of the input.
    ttnn::Tensor row_major_tensor =
        ttnn::to_layout(input_tensor, Layout::ROW_MAJOR, std::nullopt, std::nullopt, input_tensor.device());

    // The existing data stays at the origin; the remainder is filled with pad_value.
    ttnn::Tensor padded_tensor =
        ttnn::pad(row_major_tensor, padded_shape, tt::tt_metal::Array4D({0, 0, 0, 0}), pad_value);

    // Back to tile layout, on the device the padded tensor lives on.
    padded_tensor = ttnn::to_layout(padded_tensor, Layout::TILE, std::nullopt, std::nullopt, padded_tensor.device());

    tt::log_debug(tt::LogOp, "max {} {} {}", padded_shape, pad_value, padded_tensor);
    return padded_tensor;
}

// Pads `input_tensor` with `pad_value`, runs the reduction, and slices the
// result back to the un-padded size.
//
// Parameters:
//   input_tensor          - tensor to reduce; only read, never modified.
//   pad_value             - value written into the tile padding before reducing.
//   op                    - reduction math operation forwarded to tt::tt_metal::reduce.
//   reduce_op_dim         - dimension(s) reduced: W, H or HW.
//   scalar                - scalar forwarded to tt::tt_metal::reduce.
//   memory_config         - memory configuration forwarded to tt::tt_metal::reduce.
//   compute_kernel_config - optional compute kernel configuration, forwarded as-is.
//
// Returns the reduced tensor sliced to the input's shape, with each reduced
// dimension cut down to extent 1.
//
// Throws (via TT_THROW) if `reduce_op_dim` is not W, H or HW.
//
// NOTE(review): dimensions 0..3 of the input shape are indexed directly, so this
// appears to assume a rank-4 input — confirm callers guarantee that.
ttnn::Tensor reduce_with_padding(
    const ttnn::Tensor& input_tensor,
    float pad_value,
    tt::tt_metal::ReduceOpMath op,
    const tt::tt_metal::ReduceOpDim reduce_op_dim,
    float scalar,
    const ttnn::MemoryConfig& memory_config,
    const std::optional<ttnn::DeviceComputeKernelConfig>& compute_kernel_config) {
    // Reduce over a copy whose tile padding holds pad_value.
    ttnn::Tensor padded_tensor = pad_tensor_with_value(input_tensor, pad_value);
    ttnn::Tensor output_tensor = tt::tt_metal::reduce(
        padded_tensor, op, reduce_op_dim, scalar, memory_config, std::nullopt, compute_kernel_config);

    // Slice window: starts as the full un-padded input extent; each reduced
    // dimension is then narrowed to a single element.
    ttnn::Shape shape = input_tensor.get_shape();
    std::array<uint32_t, 4> begins = {0, 0, 0, 0};
    std::array<uint32_t, 4> ends = {shape[0], shape[1], shape[2], shape[3]};
    std::array<uint32_t, 4> step = {1, 1, 1, 1};
    switch (reduce_op_dim) {
        case tt::tt_metal::ReduceOpDim::W: ends[3] = 1; break;
        case tt::tt_metal::ReduceOpDim::H: ends[2] = 1; break;
        case tt::tt_metal::ReduceOpDim::HW:
            ends[2] = 1;
            ends[3] = 1;
            break;
        default: TT_THROW("Unsupported reduce op dim {}", reduce_op_dim);
    }

    return ttnn::slice(output_tensor, begins, ends, step);
}

namespace ttnn {
namespace operations::reduction {

Expand All @@ -22,10 +72,6 @@ static Tensor reduce_impl(
float scalar,
bool reshape) {
using ttnn::operations::experimental::auto_format::AutoFormat;
if (not keepdim) {
TT_THROW("keepdim=False is not supported");
}

auto input_shape = input_tensor_arg.get_shape();
auto rank = input_shape.size();
auto memory_config = memory_config_arg.value_or(input_tensor_arg.memory_config());
Expand Down Expand Up @@ -58,41 +104,39 @@ static Tensor reduce_impl(
rank);
}

if (dim.size() == 1 && rank == 4) {
if (dim[0] == rank - 3) {
auto out_shape = input_tensor_arg.get_legacy_shape();
out_shape[1] = 1;

Tensor output = ttnn::transpose(input_tensor_arg, 1, -2, memory_config);
output = reduce_impl<reduce_type>(output, 2, keepdim, memory_config, compute_kernel_config, scalar, false);
output = ttnn::transpose(output, 1, -2, memory_config);
return AutoFormat::format_output_tensor(output, out_shape, input_tensor_arg.device(), Layout::TILE);
} else if (dim[0] == 0) {
auto out_shape = input_tensor_arg.get_legacy_shape();
out_shape[0] = 1;

Tensor output = ttnn::transpose(input_tensor_arg, 0, -2, memory_config);
output = reduce_impl<reduce_type>(output, 2, keepdim, memory_config, compute_kernel_config, scalar, false);
output = ttnn::transpose(output, 0, -2, memory_config);
return AutoFormat::format_output_tensor(output, out_shape, input_tensor_arg.device(), Layout::TILE);
}
}

std::sort(dim.begin(), dim.end());

ttnn::SmallVector<uint32_t> output_shape;
ttnn::SmallVector<uint32_t> padded_output_shape;
for (int axis = 0; axis < input_shape.size(); axis++) {
if (std::find(dim.begin(), dim.end(), axis) != dim.end()) {
if (keepdim) {
output_shape.push_back(1);
padded_output_shape.push_back(axis >= rank - 2 ? ttnn::TILE_SIZE : 1);
}
} else {
// Get the shape for the output tensor
output_shape.push_back(input_shape[axis]);
// Get the padded shape for the output tensor
padded_output_shape.push_back(input_shape.value[axis]);
}
}

if (dim.size() == 1 && (rank == 3 || rank == 4)) {
if (dim[0] == 1 && rank == 4) {
Tensor output = ttnn::transpose(input_tensor_arg, 1, -2, memory_config);
output = reduce_impl<reduce_type>(
output, 2, /*keepdim=*/true, memory_config, compute_kernel_config, scalar, /*reshape=*/true);
output = ttnn::transpose(output, 1, -2, memory_config);
if (reshape) {
output = ttnn::reshape(output, ttnn::Shape{output_shape});
}
return output;
} else if (dim[0] == 0) {
Tensor output = ttnn::transpose(input_tensor_arg, 0, -2, memory_config);
output = reduce_impl<reduce_type>(
output, -2, /*keepdim=*/true, memory_config, compute_kernel_config, scalar, /*reshape=*/true);
output = ttnn::transpose(output, 0, -2, memory_config);
if (reshape) {
output = ttnn::reshape(output, ttnn::Shape{output_shape});
}
return output;
}
}

Expand Down Expand Up @@ -154,22 +198,22 @@ static Tensor reduce_impl(
std::nullopt,
compute_kernel_config);
} else if constexpr (reduce_type == ReduceType::Max) {
output_tensor = tt::tt_metal::reduce(
output_tensor = reduce_with_padding(
input_tensor,
-std::numeric_limits<float>::infinity(),
tt::tt_metal::ReduceOpMath::MAX,
reduce_op_dim,
scalar,
memory_config,
std::nullopt,
compute_kernel_config);
} else if constexpr (reduce_type == ReduceType::Min) {
output_tensor = tt::tt_metal::reduce(
output_tensor = reduce_with_padding(
input_tensor,
std::numeric_limits<float>::infinity(),
tt::tt_metal::ReduceOpMath::MIN,
reduce_op_dim,
scalar,
memory_config,
std::nullopt,
compute_kernel_config);
} else if constexpr (reduce_type == ReduceType::Var or reduce_type == ReduceType::Std) {
auto mean_tensor = tt::tt_metal::reduce(
Expand Down Expand Up @@ -199,7 +243,7 @@ static Tensor reduce_impl(
}

if (reshape) {
output_tensor = ttnn::reshape(output_tensor, ttnn::Shape{output_shape, padded_output_shape});
output_tensor = ttnn::reshape(output_tensor, ttnn::Shape{output_shape});
}

return output_tensor;
Expand Down

0 comments on commit ebf7dbd

Please sign in to comment.